From 611e8c983f8f245729e9a77e8cfc6344dbec9a6c Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Mon, 1 Jun 2026 19:07:27 +0530 Subject: [PATCH 1/9] Changes for passkey implementation --- .../auth_schemes/__init__.py | 3 +- .../auth_schemes/dpop_auth.py | 51 + .../auth_server/my_account_client.py | 464 ++++++++- .../auth_server/server_client.py | 967 +++++++++++------- .../auth_types/__init__.py | 274 ++++- .../tests/test_dpop_auth.py | 145 +++ .../tests/test_passkey_my_account.py | 473 +++++++++ .../tests/test_passkey_server_client.py | 523 ++++++++++ 8 files changed, 2483 insertions(+), 417 deletions(-) create mode 100644 src/auth0_server_python/auth_schemes/dpop_auth.py create mode 100644 src/auth0_server_python/tests/test_dpop_auth.py create mode 100644 src/auth0_server_python/tests/test_passkey_my_account.py create mode 100644 src/auth0_server_python/tests/test_passkey_server_client.py diff --git a/src/auth0_server_python/auth_schemes/__init__.py b/src/auth0_server_python/auth_schemes/__init__.py index 1c2c869..ef37613 100644 --- a/src/auth0_server_python/auth_schemes/__init__.py +++ b/src/auth0_server_python/auth_schemes/__init__.py @@ -1,3 +1,4 @@ from .bearer_auth import BearerAuth +from .dpop_auth import DPoPAuth -__all__ = ["BearerAuth"] +__all__ = ["BearerAuth", "DPoPAuth"] diff --git a/src/auth0_server_python/auth_schemes/dpop_auth.py b/src/auth0_server_python/auth_schemes/dpop_auth.py new file mode 100644 index 0000000..1517a78 --- /dev/null +++ b/src/auth0_server_python/auth_schemes/dpop_auth.py @@ -0,0 +1,51 @@ +import base64 +import hashlib +import time +import uuid + +import httpx +from jwcrypto import jwk +from jwcrypto import jwt as jwcrypto_jwt + + +def _base64url(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +class DPoPAuth(httpx.Auth): + def __init__(self, token: str, key: "jwk.JWK") -> None: + public_jwk = key.export_public(as_dict=True) + if public_jwk.get("kty") != "EC" or public_jwk.get("crv") != "P-256": + raise ValueError("DPoP key must be an EC P-256 key") + self._token = token + self._key = key + self._public_jwk = public_jwk + + def __repr__(self) -> str: + return "DPoPAuth(token=[REDACTED], key=[REDACTED])" + + def __str__(self) -> str: + return "DPoPAuth(token=[REDACTED], key=[REDACTED])" + + def auth_flow(self, request: httpx.Request): + proof = self._make_proof(request.method, str(request.url)) + request.headers["Authorization"] = f"DPoP {self._token}" + request.headers["DPoP"] = proof + yield request + + def _make_proof(self, method: str, url: str) -> str: + htu = url.split("?")[0].split("#")[0] + ath = _base64url(hashlib.sha256(self._token.encode("ascii")).digest()) + + header = {"typ": "dpop+jwt", "alg": "ES256", "jwk": self._public_jwk} + payload = { + "jti": str(uuid.uuid4()), + "htm": method.upper(), + "htu": htu, + "iat": int(time.time()), + "ath": ath, + } + + token = jwcrypto_jwt.JWT(header=header, claims=payload) + token.make_signed_token(self._key) + return token.serialize() diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index 499b981..a6aed8f 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -1,16 +1,26 @@ - -from typing import Optional +import json +from typing import TYPE_CHECKING, Optional +from urllib.parse import quote import httpx +from pydantic import ValidationError from auth0_server_python.auth_schemes.bearer_auth import BearerAuth +from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth from auth0_server_python.auth_types import ( + AuthenticationMethod, CompleteConnectAccountRequest, CompleteConnectAccountResponse, ConnectAccountRequest, ConnectAccountResponse, + EnrollAuthenticationMethodRequest, + EnrollmentChallengeResponse, + GetFactorsResponse, + ListAuthenticationMethodsResponse, ListConnectedAccountConnectionsResponse, ListConnectedAccountsResponse, + UpdateAuthenticationMethodRequest, + VerifyAuthenticationMethodRequest, ) from auth0_server_python.error import ( ApiError, @@ -19,6 +29,18 @@ MyAccountApiError, ) +if TYPE_CHECKING: + from jwcrypto import jwk + + +def _make_auth( + access_token: str, + dpop_key: Optional["jwk.JWK"] = None, +) -> httpx.Auth: + if dpop_key is not None: + return DPoPAuth(access_token, dpop_key) + return BearerAuth(access_token) + class MyAccountClient: """ @@ -52,9 +74,7 @@ def audience(self): return f"https://{self._domain}/me/" async def connect_account( - self, - access_token: str, - request: ConnectAccountRequest + self, access_token: str, request: ConnectAccountRequest ) -> ConnectAccountResponse: """ Initiate the connected account flow. @@ -75,7 +95,7 @@ async def connect_account( response = await client.post( url=f"{self.audience}v1/connected-accounts/connect", json=request.model_dump(exclude_none=True), - auth=BearerAuth(access_token) + auth=BearerAuth(access_token), ) if response.status_code != 201: @@ -85,7 +105,7 @@ async def connect_account( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None) + validation_errors=error_data.get("validation_errors", None), ) data = response.json() @@ -98,13 +118,11 @@ async def connect_account( raise ApiError( "connect_account_error", f"Connected Accounts connect request failed: {str(e) or 'Unknown error'}", - e + e, ) async def complete_connect_account( - self, - access_token: str, - request: CompleteConnectAccountRequest + self, access_token: str, request: CompleteConnectAccountRequest ) -> CompleteConnectAccountResponse: """ Complete the connected account flow after user authorization. @@ -125,7 +143,7 @@ async def complete_connect_account( response = await client.post( url=f"{self.audience}v1/connected-accounts/complete", json=request.model_dump(exclude_none=True), - auth=BearerAuth(access_token) + auth=BearerAuth(access_token), ) if response.status_code != 201: @@ -135,7 +153,7 @@ async def complete_connect_account( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None) + validation_errors=error_data.get("validation_errors", None), ) data = response.json() @@ -148,7 +166,7 @@ async def complete_connect_account( raise ApiError( "connect_account_error", f"Connected Accounts complete request failed: {str(e) or 'Unknown error'}", - e + e, ) async def list_connected_accounts( @@ -156,7 +174,7 @@ async def list_connected_accounts( access_token: str, connection: Optional[str] = None, from_param: Optional[str] = None, - take: Optional[int] = None + take: Optional[int] = None, ) -> ListConnectedAccountsResponse: """ List connected accounts for the authenticated user. @@ -195,7 +213,7 @@ async def list_connected_accounts( response = await client.get( url=f"{self.audience}v1/connected-accounts/accounts", params=params, - auth=BearerAuth(access_token) + auth=BearerAuth(access_token), ) if response.status_code != 200: @@ -205,7 +223,7 @@ async def list_connected_accounts( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None) + validation_errors=error_data.get("validation_errors", None), ) data = response.json() @@ -218,15 +236,10 @@ async def list_connected_accounts( raise ApiError( "connect_account_error", f"Connected Accounts list request failed: {str(e) or 'Unknown error'}", - e + e, ) - - async def delete_connected_account( - self, - access_token: str, - connected_account_id: str - ) -> None: + async def delete_connected_account(self, access_token: str, connected_account_id: str) -> None: """ Delete a connected account for the authenticated user. @@ -253,7 +266,7 @@ async def delete_connected_account( async with self._get_http_client() as client: response = await client.delete( url=f"{self.audience}v1/connected-accounts/accounts/{connected_account_id}", - auth=BearerAuth(access_token) + auth=BearerAuth(access_token), ) if response.status_code != 204: @@ -263,7 +276,7 @@ async def delete_connected_account( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None) + validation_errors=error_data.get("validation_errors", None), ) except Exception as e: @@ -272,14 +285,11 @@ async def delete_connected_account( raise ApiError( "connect_account_error", f"Connected Accounts delete request failed: {str(e) or 'Unknown error'}", - e + e, ) async def list_connected_account_connections( - self, - access_token: str, - from_param: Optional[str] = None, - take: Optional[int] = None + self, access_token: str, from_param: Optional[str] = None, take: Optional[int] = None ) -> ListConnectedAccountConnectionsResponse: """ List available connections that support connected accounts. @@ -315,7 +325,7 @@ async def list_connected_account_connections( response = await client.get( url=f"{self.audience}v1/connected-accounts/connections", params=params, - auth=BearerAuth(access_token) + auth=BearerAuth(access_token), ) if response.status_code != 200: @@ -325,7 +335,7 @@ async def list_connected_account_connections( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None) + validation_errors=error_data.get("validation_errors", None), ) data = response.json() @@ -338,5 +348,391 @@ async def list_connected_account_connections( raise ApiError( "connect_account_error", f"Connected Accounts list connections request failed: {str(e) or 'Unknown error'}", - e + e, + ) + + async def get_factors( + self, + access_token: str, + dpop_key: Optional["jwk.JWK"] = None, + ) -> GetFactorsResponse: + """ + Retrieve the list of factors available for enrollment. + + Args: + access_token: User's access token (scope: read:me:factors). + dpop_key: Optional EC P-256 key for DPoP-bound token presentation. + + Returns: + GetFactorsResponse containing the available factors. + + Raises: + MissingRequiredArgumentError: If access_token is not provided. + MyAccountApiError: If the API returns an error response. + ApiError: If the request fails due to network or other issues. + """ + if not access_token: + raise MissingRequiredArgumentError("access_token") + + try: + async with self._get_http_client() as client: + response = await client.get( + url=f"{self.audience}v1/factors", + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "get_factors_error", + f"Get factors failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + return GetFactorsResponse.model_validate(response.json()) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "get_factors_error", + "Get factors request failed", + e, + ) + + async def list_authentication_methods( + self, + access_token: str, + type_filter: Optional[str] = None, + dpop_key: Optional["jwk.JWK"] = None, + ) -> ListAuthenticationMethodsResponse: + if not access_token: + raise MissingRequiredArgumentError("access_token") + + try: + async with self._get_http_client() as client: + params = {} + if type_filter: + params["type"] = type_filter + + response = await client.get( + url=f"{self.audience}v1/authentication-methods", + params=params, + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "list_authentication_methods_error", + f"List authentication methods failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + return ListAuthenticationMethodsResponse.model_validate(response.json()) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "list_authentication_methods_error", + "List authentication methods request failed", + e, + ) + + async def get_authentication_method( + self, + access_token: str, + authentication_method_id: str, + dpop_key: Optional["jwk.JWK"] = None, + ) -> AuthenticationMethod: + if not access_token: + raise MissingRequiredArgumentError("access_token") + if not authentication_method_id: + raise MissingRequiredArgumentError("authentication_method_id") + + try: + async with self._get_http_client() as client: + response = await client.get( + url=f"{self.audience}v1/authentication-methods/{quote(authentication_method_id, safe='')}", + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "get_authentication_method_error", + f"Get authentication method failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + return AuthenticationMethod.model_validate(response.json()) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "get_authentication_method_error", + "Get authentication method request failed", + e, + ) + + async def delete_authentication_method( + self, + access_token: str, + authentication_method_id: str, + dpop_key: Optional["jwk.JWK"] = None, + ) -> None: + if not access_token: + raise MissingRequiredArgumentError("access_token") + if not authentication_method_id: + raise MissingRequiredArgumentError("authentication_method_id") + + try: + async with self._get_http_client() as client: + response = await client.delete( + url=f"{self.audience}v1/authentication-methods/{quote(authentication_method_id, safe='')}", + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 204: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "delete_authentication_method_error", + f"Delete authentication method failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "delete_authentication_method_error", + "Delete authentication method request failed", + e, + ) + + async def update_authentication_method( + self, + access_token: str, + authentication_method_id: str, + request: UpdateAuthenticationMethodRequest, + dpop_key: Optional["jwk.JWK"] = None, + ) -> AuthenticationMethod: + if not access_token: + raise MissingRequiredArgumentError("access_token") + if not authentication_method_id: + raise MissingRequiredArgumentError("authentication_method_id") + if request is None: + raise MissingRequiredArgumentError("request") + + try: + async with self._get_http_client() as client: + response = await client.patch( + url=f"{self.audience}v1/authentication-methods/{quote(authentication_method_id, safe='')}", + json=request.model_dump(exclude_none=True), + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "update_authentication_method_error", + f"Update authentication method failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + return AuthenticationMethod.model_validate(response.json()) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "update_authentication_method_error", + "Update authentication method request failed", + e, + ) + + async def enroll_authentication_method( + self, + access_token: str, + request: EnrollAuthenticationMethodRequest, + dpop_key: Optional["jwk.JWK"] = None, + ) -> EnrollmentChallengeResponse: + """Step 1 of 2: Start enrollment (POST /me/v1/authentication-methods). + + For passkey enrollment, pass the returned authn_params_public_key to + navigator.credentials.create(), then call verify_authentication_method() + with the auth_session and credential result. + + Requires scope: create:me:authentication_methods + """ + if not access_token: + raise MissingRequiredArgumentError("access_token") + if request is None: + raise MissingRequiredArgumentError("request") + + try: + async with self._get_http_client() as client: + response = await client.post( + url=f"{self.audience}v1/authentication-methods", + json=request.model_dump(exclude_none=True), + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 201: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "enroll_authentication_method_error", + f"Enroll authentication method failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + location = response.headers.get("location") + if not location: + raise ApiError( + "enroll_authentication_method_error", + "Enrollment succeeded (201) but Location header is missing", + ) + + authentication_method_id = ( + location.split("?")[0].split("#")[0].rstrip("/").split("/")[-1] + ) + if not authentication_method_id: + raise ApiError( + "enroll_authentication_method_error", + "Enrollment succeeded (201) but could not extract ID from Location header", + ) + + try: + data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "enroll_authentication_method_error", + "Enrollment succeeded (201) but response body is not valid JSON", + ) + + auth_session = data.get("auth_session") + if not auth_session: + raise ApiError( + "enroll_authentication_method_error", + "Enrollment succeeded (201) but auth_session is missing from response", + ) + + return EnrollmentChallengeResponse.model_validate( + { + "authentication_method_id": authentication_method_id, + "auth_session": auth_session, + "authn_params_public_key": data.get("authn_params_public_key"), + } + ) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "enroll_authentication_method_error", + "Enroll authentication method request failed", + e, + ) + + async def verify_authentication_method( + self, + access_token: str, + authentication_method_id: str, + request: VerifyAuthenticationMethodRequest, + dpop_key: Optional["jwk.JWK"] = None, + ) -> AuthenticationMethod: + """Step 2 of 2: Verify enrollment (POST /me/v1/authentication-methods/{id}/verify). + + Requires scope: create:me:authentication_methods + """ + if not access_token: + raise MissingRequiredArgumentError("access_token") + if not authentication_method_id: + raise MissingRequiredArgumentError("authentication_method_id") + if request is None: + raise MissingRequiredArgumentError("request") + + try: + async with self._get_http_client() as client: + response = await client.post( + url=f"{self.audience}v1/authentication-methods/{quote(authentication_method_id, safe='')}/verify", + json=request.model_dump(by_alias=True, exclude_none=True), + auth=_make_auth(access_token, dpop_key), + ) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "verify_authentication_method_error", + f"Verify authentication method failed with status {response.status_code}", + ) + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None), + ) + + return AuthenticationMethod.model_validate(response.json()) + + except Exception as e: + if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + raise + raise ApiError( + "verify_authentication_method_error", + "Verify authentication method request failed", + e, ) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 91de45d..334eb00 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -31,6 +31,10 @@ LogoutOptions, LogoutTokenClaims, MfaRequirements, + PasskeyAuthResponse, + PasskeyLoginChallengeResponse, + PasskeySignupChallengeResponse, + PasskeyTokenResponse, StartInteractiveLoginOptions, StateData, TokenExchangeResponse, @@ -65,11 +69,18 @@ ) # Generic type for store options -TStoreOptions = TypeVar('TStoreOptions') +TStoreOptions = TypeVar("TStoreOptions") # redirect_uri is intentionally excluded — in MCD mode it is built # dynamically from the resolved domain at login time. -INTERNAL_AUTHORIZE_PARAMS = ["client_id", "response_type", - "code_challenge", "code_challenge_method", "state", "nonce", "scope"] +INTERNAL_AUTHORIZE_PARAMS = [ + "client_id", + "response_type", + "code_challenge", + "code_challenge_method", + "state", + "nonce", + "scope", +] class ServerClient(Generic[TStoreOptions]): @@ -77,6 +88,7 @@ class ServerClient(Generic[TStoreOptions]): Main client for Auth0 server SDK. Handles authentication flows, session management, and token operations using Authlib for OIDC functionality. """ + DEFAULT_AUDIENCE_STATE_KEY = "default" # ============================================================================ @@ -117,9 +129,7 @@ def __init__( raise MissingRequiredArgumentError("secret") if domain is None: - raise ConfigurationError( - "Domain is required" - ) + raise ConfigurationError("Domain is required") # Validate domain type if not isinstance(domain, str) and not callable(domain): @@ -164,14 +174,12 @@ def __init__( headers=self._telemetry_headers, ) - self._my_account_client = MyAccountClient( - domain=domain, headers=self._telemetry_headers - ) + self._my_account_client = MyAccountClient(domain=domain, headers=self._telemetry_headers) # Unified cache for OIDC metadata and JWKS per domain (LRU eviction + TTL) self._discovery_cache: OrderedDict[str, dict] = OrderedDict() - self._cache_ttl = 600 # 10 mins. TTL - self._cache_max_entries = 100 # Max 100 domains + self._cache_ttl = 600 # 10 mins. TTL + self._cache_max_entries = 100 # Max 100 domains # Initialize MFA client self._mfa_client = MfaClient( @@ -198,14 +206,14 @@ def _normalize_url(self, value: str) -> str: return value value = value.lower() - if value.startswith('https://'): + if value.startswith("https://"): pass - elif value.startswith('http://'): - value = value.replace('http://', 'https://') + elif value.startswith("http://"): + value = value.replace("http://", "https://") else: - value = f'https://{value}' + value = f"https://{value}" - return value.rstrip('/') + return value.rstrip("/") async def _resolve_current_domain(self, store_options=None) -> str: """Resolve domain from resolver function or return static domain.""" @@ -218,8 +226,7 @@ async def _resolve_current_domain(self, store_options=None) -> str: raise except Exception as e: raise DomainResolverError( - f"Domain resolver function raised an exception: {str(e)}", - original_error=e + f"Domain resolver function raised an exception: {str(e)}", original_error=e ) return self._domain @@ -233,18 +240,18 @@ def _get_session_domain(self, state_data_dict: dict) -> Optional[str]: 2. self._domain — static domain (if configured) 3. Extract hostname from user.iss — derive from user's issuer claim """ - domain = state_data_dict.get('domain') + domain = state_data_dict.get("domain") if domain: return domain if self._domain: return self._domain - user = state_data_dict.get('user') + user = state_data_dict.get("user") if isinstance(user, dict): - iss = user.get('iss') + iss = user.get("iss") else: - iss = getattr(user, 'iss', None) if user else None + iss = getattr(user, "iss", None) if user else None if iss: parsed = urlparse(iss) @@ -347,7 +354,7 @@ async def _get_oidc_metadata_cached(self, domain: str) -> dict: self._discovery_cache[domain] = { "metadata": metadata, "jwks": None, - "expires_at": now + self._cache_ttl + "expires_at": now + self._cache_ttl, } return metadata @@ -409,11 +416,11 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: if not metadata: metadata = await self._get_oidc_metadata_cached(domain) - jwks_uri = metadata.get('jwks_uri') + jwks_uri = metadata.get("jwks_uri") if not jwks_uri: raise ApiError( "missing_jwks_uri", - f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant." + f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant.", ) # Fetch JWKS @@ -430,7 +437,7 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: self._discovery_cache[domain] = { "metadata": metadata, "jwks": jwks, - "expires_at": now + self._cache_ttl + "expires_at": now + self._cache_ttl, } return jwks @@ -442,9 +449,7 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: # ============================================================================ async def start_interactive_login( - self, - options: Optional[StartInteractiveLoginOptions] = None, - store_options: dict = None + self, options: Optional[StartInteractiveLoginOptions] = None, store_options: dict = None ) -> str: """ Starts the interactive login process and returns a URL to redirect to. @@ -465,15 +470,17 @@ async def start_interactive_login( try: metadata = await self._get_oidc_metadata_cached(origin_domain) except Exception as e: - raise ApiError("metadata_error", - "Failed to fetch OIDC metadata", e) + raise ApiError("metadata_error", "Failed to fetch OIDC metadata", e) # Get effective authorization params (merge defaults with provided ones) auth_params = dict(self._default_authorization_params) if options.authorization_params: auth_params.update( - {k: v for k, v in options.authorization_params.items( - ) if k not in INTERNAL_AUTHORIZE_PARAMS} + { + k: v + for k, v in options.authorization_params.items() + if k not in INTERNAL_AUTHORIZE_PARAMS + } ) # Ensure we have a redirect_uri @@ -497,7 +504,11 @@ async def start_interactive_login( auth_params["state"] = state # Merge any requested scope with defaults - requested_scope = options.authorization_params.get("scope", None) if options.authorization_params else None + requested_scope = ( + options.authorization_params.get("scope", None) + if options.authorization_params + else None + ) audience = auth_params.get("audience", None) merged_scope = self._merge_scope_with_defaults(requested_scope, audience) auth_params["scope"] = merged_scope @@ -513,65 +524,61 @@ async def start_interactive_login( # Store the transaction data await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", - transaction_data, - options=store_options + f"{self._transaction_identifier}:{state}", transaction_data, options=store_options ) # Set metadata for OAuth client self._oauth.metadata = metadata # If PAR is enabled, use the PAR endpoint if self._pushed_authorization_requests: - par_endpoint = self._oauth.metadata.get( - "pushed_authorization_request_endpoint") + par_endpoint = self._oauth.metadata.get("pushed_authorization_request_endpoint") if not par_endpoint: raise ApiError( - "configuration_error", "PAR is enabled but pushed_authorization_request_endpoint is missing in metadata") + "configuration_error", + "PAR is enabled but pushed_authorization_request_endpoint is missing in metadata", + ) auth_params["client_id"] = self._client_id # Post the auth_params to the PAR endpoint async with self._get_http_client() as client: par_response = await client.post( - par_endpoint, - data=auth_params, - auth=(self._client_id, self._client_secret) + par_endpoint, data=auth_params, auth=(self._client_id, self._client_secret) ) if par_response.status_code not in (200, 201): error_data = par_response.json() raise ApiError( error_data.get("error", "par_error"), error_data.get( - "error_description", "Failed to obtain request_uri from PAR endpoint") + "error_description", "Failed to obtain request_uri from PAR endpoint" + ), ) par_data = par_response.json() request_uri = par_data.get("request_uri") if not request_uri: - raise ApiError( - "par_error", "No request_uri returned from PAR endpoint") + raise ApiError("par_error", "No request_uri returned from PAR endpoint") auth_endpoint = self._oauth.metadata.get("authorization_endpoint") final_url = f"{auth_endpoint}?request_uri={request_uri}&response_type={auth_params['response_type']}&client_id={self._client_id}" return final_url else: if "authorization_endpoint" not in self._oauth.metadata: - raise ApiError("configuration_error", - "Authorization endpoint missing in OIDC metadata") + raise ApiError( + "configuration_error", "Authorization endpoint missing in OIDC metadata" + ) authorization_endpoint = self._oauth.metadata["authorization_endpoint"] try: auth_url, state = self._oauth.create_authorization_url( - authorization_endpoint, **auth_params) + authorization_endpoint, **auth_params + ) except Exception as e: - raise ApiError("authorization_url_error", - "Failed to create authorization URL", e) + raise ApiError("authorization_url_error", "Failed to create authorization URL", e) return auth_url async def complete_interactive_login( - self, - url: str, - store_options: dict = None + self, url: str, store_options: dict = None ) -> dict[str, Any]: """ Completes the login process after user is redirected back. @@ -594,7 +601,9 @@ async def complete_interactive_login( # Retrieve the transaction data using the state transaction_identifier = f"{self._transaction_identifier}:{state}" - transaction_data = await self._transaction_store.get(transaction_identifier, options=store_options) + transaction_data = await self._transaction_store.get( + transaction_identifier, options=store_options + ) if not transaction_data: raise MissingTransactionError() @@ -615,7 +624,7 @@ async def complete_interactive_login( # Fetch metadata and derive issuer from the origin domain metadata = await self._get_oidc_metadata_cached(origin_domain) - origin_issuer = metadata.get('issuer') + origin_issuer = metadata.get("issuer") self._oauth.metadata = metadata # Exchange the code for tokens @@ -631,8 +640,7 @@ async def complete_interactive_login( ) except OAuthError as e: # Raise a custom error (or handle it as appropriate) - raise ApiError( - "token_error", f"Token exchange failed: {str(e)}", e) + raise ApiError("token_error", f"Token exchange failed: {str(e)}", e) # Use the userinfo field from the token_response for user claims user_info = token_response.get("userinfo") @@ -647,14 +655,14 @@ async def complete_interactive_login( # Decode and verify ID token with signature verification enabled try: - claims = await self._verify_and_decode_jwt( - id_token, jwks, audience=self._client_id - ) + claims = await self._verify_and_decode_jwt(id_token, jwks, audience=self._client_id) # Custom normalized issuer validation token_issuer = claims.get("iss", "") if self._normalize_url(token_issuer) != self._normalize_url(origin_issuer): - raise IssuerValidationError("ID token issuer mismatch. Ensure your Auth0 domain is configured correctly.") + raise IssuerValidationError( + "ID token issuer mismatch. Ensure your Auth0 domain is configured correctly." + ) user_claims = UserClaims.parse_obj(claims) except ValueError as e: @@ -663,40 +671,33 @@ async def complete_interactive_login( raise ApiError( "invalid_signature", f"ID token signature verification failed. The token may have been tampered with or is from an untrusted source: {str(e)}", - e + e, ) except jwt.InvalidAudienceError as e: raise ApiError( "invalid_audience", f"ID token audience mismatch. Expected: {self._client_id}. Ensure your client_id is configured correctly: {str(e)}", - e + e, ) except jwt.ExpiredSignatureError as e: - raise ApiError( - "token_expired", - f"ID token has expired: {str(e)}", - e - ) + raise ApiError("token_expired", f"ID token has expired: {str(e)}", e) except jwt.InvalidTokenError as e: - raise ApiError( - "invalid_token", - f"ID token verification failed: {str(e)}", - e - ) - + raise ApiError("invalid_token", f"ID token verification failed: {str(e)}", e) # Build a token set using the token response data token_set = TokenSet( audience=transaction_data.audience or self.DEFAULT_AUDIENCE_STATE_KEY, access_token=token_response.get("access_token", ""), scope=token_response.get("scope", ""), - expires_at=int(time.time()) + - token_response.get("expires_in", 3600) + expires_at=int(time.time()) + token_response.get("expires_in", 3600), ) # Generate a session id (sid) from token_response or transaction data, or create a new one - sid = user_info.get( - "sid") if user_info and "sid" in user_info else PKCE.generate_random_string(32) + sid = ( + user_info.get("sid") + if user_info and "sid" in user_info + else PKCE.generate_random_string(32) + ) # Construct state data to represent the session state_data = StateData( @@ -706,10 +707,7 @@ async def complete_interactive_login( refresh_token=token_response.get("refresh_token"), token_sets=[token_set], domain=origin_domain, - internal={ - "sid": sid, - "created_at": int(time.time()) - } + internal={"sid": sid, "created_at": int(time.time())}, ) # Store the state data in the state store using store_options (Response required) @@ -734,7 +732,9 @@ async def complete_interactive_login( # Methods for retrieving user information, session data, and logout operations. # ============================================================================ - async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Optional[dict[str, Any]]: + async def get_user( + self, store_options: Optional[dict[str, Any]] = None + ) -> Optional[dict[str, Any]]: """ Retrieves the user from the store, or None if no user found. @@ -763,7 +763,9 @@ async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Opti return state_data.get("user") return None - async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> Optional[dict[str, Any]]: + async def get_session( + self, store_options: Optional[dict[str, Any]] = None + ) -> Optional[dict[str, Any]]: """ Retrieve the user session from the store, or None if no session found. @@ -789,15 +791,14 @@ async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> O if self._normalize_url(session_domain) != self._normalize_url(current_domain): return None - session_data = {k: v for k, v in state_data.items() - if k != "internal"} + session_data = {k: v for k, v in state_data.items() if k != "internal"} return session_data return None async def logout( self, options: Optional[LogoutOptions] = None, - store_options: Optional[dict[str, Any]] = None + store_options: Optional[dict[str, Any]] = None, ) -> str: options = options or LogoutOptions() @@ -813,19 +814,18 @@ async def logout( if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() session_domain = self._get_session_domain(state_data) - if session_domain and self._normalize_url(session_domain) == self._normalize_url(domain): + if session_domain and self._normalize_url(session_domain) == self._normalize_url( + domain + ): await self._state_store.delete(self._state_identifier, store_options) # Return logout URL for the current resolved domain - logout_url = URL.create_logout_url( - domain, self._client_id, options.return_to) + logout_url = URL.create_logout_url(domain, self._client_id, options.return_to) return logout_url async def handle_backchannel_logout( - self, - logout_token: str, - store_options: Optional[dict[str, Any]] = None + self, logout_token: str, store_options: Optional[dict[str, Any]] = None ) -> None: """ Handles backchannel logout requests. @@ -846,8 +846,7 @@ async def handle_backchannel_logout( # Read iss from unverified token for comparison try: unverified = jwt.decode( - logout_token, algorithms=["RS256"], - options={"verify_signature": False} + logout_token, algorithms=["RS256"], options={"verify_signature": False} ) token_issuer = unverified.get("iss", "") except Exception as e: @@ -876,13 +875,16 @@ async def handle_backchannel_logout( jwks = await self._get_jwks_cached(domain) try: - claims = await self._verify_and_decode_jwt(logout_token, jwks, audience=self._client_id) + claims = await self._verify_and_decode_jwt( + logout_token, jwks, audience=self._client_id + ) # Normalized issuer validation token_issuer = claims.get("iss", "") expected_issuer = self._normalize_url(domain) if self._normalize_url(token_issuer) != self._normalize_url(expected_issuer): - raise IssuerValidationError("Logout token issuer mismatch.Ensure your Auth0 domain is configured correctly." + raise IssuerValidationError( + "Logout token issuer mismatch.Ensure your Auth0 domain is configured correctly." ) except ValueError as e: raise BackchannelLogoutError(str(e)) @@ -891,30 +893,22 @@ async def handle_backchannel_logout( f"Logout token signature verification failed: {str(e)}" ) except jwt.InvalidTokenError as e: - raise BackchannelLogoutError( - f"Logout token verification failed: {str(e)}" - ) + raise BackchannelLogoutError(f"Logout token verification failed: {str(e)}") # Validate the token is a logout token events = claims.get("events", {}) if "http://schemas.openid.net/event/backchannel-logout" not in events: - raise BackchannelLogoutError( - "Invalid logout token: not a backchannel logout event") + raise BackchannelLogoutError("Invalid logout token: not a backchannel logout event") # Delete sessions associated with this token logout_claims = LogoutTokenClaims( - sub=claims.get("sub"), - sid=claims.get("sid"), - iss=claims.get("iss") + sub=claims.get("sub"), sid=claims.get("sid"), iss=claims.get("iss") ) - await self._state_store.delete_by_logout_token( - logout_claims.dict(), store_options - ) + await self._state_store.delete_by_logout_token(logout_claims.dict(), store_options) except (jwt.PyJWTError, ValidationError) as e: - raise BackchannelLogoutError( - f"Error processing logout token: {str(e)}") + raise BackchannelLogoutError(f"Error processing logout token: {str(e)}") # ============================================================================ # ACCESS TOKEN MANAGEMENT @@ -955,13 +949,13 @@ async def get_access_token( if not session_domain: raise AccessTokenError( AccessTokenErrorCode.MISSING_SESSION_DOMAIN, - "Session domain does not match the current domain." + "Session domain does not match the current domain.", ) current_domain = await self._resolve_current_domain(store_options) if self._normalize_url(session_domain) != self._normalize_url(current_domain): raise AccessTokenError( AccessTokenErrorCode.DOMAIN_MISMATCH, - "Session domain does not match the current domain." + "Session domain does not match the current domain.", ) auth_params = self._default_authorization_params or {} @@ -975,7 +969,9 @@ async def get_access_token( # Find matching token set token_set = None if state_data_dict and "token_sets" in state_data_dict: - token_set = self._find_matching_token_set(state_data_dict["token_sets"], audience, merged_scope) + token_set = self._find_matching_token_set( + state_data_dict["token_sets"], audience, merged_scope + ) # If token is valid, return it if token_set and token_set.get("expires_at", 0) > time.time(): @@ -985,7 +981,7 @@ async def get_access_token( if not state_data_dict or not state_data_dict.get("refresh_token"): raise AccessTokenError( AccessTokenErrorCode.MISSING_REFRESH_TOKEN, - "The access token has expired and a refresh token was not provided. The user needs to re-authenticate." + "The access token has expired and a refresh token was not provided. The user needs to re-authenticate.", ) # Get new token with refresh token @@ -994,7 +990,7 @@ async def get_access_token( session_domain = state_data_dict.get("domain") or self._domain get_refresh_token_options = { "refresh_token": state_data_dict["refresh_token"], - "domain": session_domain + "domain": session_domain, } if audience: get_refresh_token_options["audience"] = audience @@ -1002,15 +998,20 @@ async def get_access_token( if merged_scope: get_refresh_token_options["scope"] = merged_scope - token_endpoint_response = await self.get_token_by_refresh_token(get_refresh_token_options) + token_endpoint_response = await self.get_token_by_refresh_token( + get_refresh_token_options + ) # Update state data with new token existing_state_data = await self._state_store.get(self._state_identifier, store_options) updated_state_data = State.update_state_data( - audience, existing_state_data, token_endpoint_response) + audience, existing_state_data, token_endpoint_response + ) # Store updated state - await self._state_store.set(self._state_identifier, updated_state_data, options=store_options) + await self._state_store.set( + self._state_identifier, updated_state_data, options=store_options + ) return token_endpoint_response["access_token"] except Exception as e: @@ -1024,22 +1025,21 @@ async def get_access_token( raw_mfa_token=raw_mfa_token, audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, scope=merged_scope or "", - mfa_requirements=mfa_requirements + mfa_requirements=mfa_requirements, ) raise MfaRequiredError( "Multifactor authentication required", mfa_token=encrypted_token, - mfa_requirements=mfa_requirements + mfa_requirements=mfa_requirements, ) if isinstance(e, AccessTokenError): raise raise AccessTokenError( AccessTokenErrorCode.REFRESH_TOKEN_ERROR, - f"Failed to get token with refresh token: {str(e)}" + f"Failed to get token with refresh token: {str(e)}", ) - async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, Any]: """ Retrieves a token by exchanging a refresh token. @@ -1067,8 +1067,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", - "Token endpoint missing in OIDC metadata") + raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") # Prepare the token request parameters token_params = { @@ -1083,8 +1082,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Merge scope if present in options with any in the original authorization params merged_scope = self._merge_scope_with_defaults( - request_scope=options.get("scope"), - audience=audience + request_scope=options.get("scope"), audience=audience ) if merged_scope: @@ -1093,9 +1091,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Exchange the refresh token for an access token async with self._get_http_client() as client: response = await client.post( - token_endpoint, - data=token_params, - auth=(self._client_id, self._client_secret) + token_endpoint, data=token_params, auth=(self._client_id, self._client_secret) ) if response.status_code != 200: @@ -1105,8 +1101,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Preserve mfa_required details for upstream handling if error_code == "mfa_required": error = ApiError( - error_code, - error_data.get("error_description", "MFA required") + error_code, error_data.get("error_description", "MFA required") ) error.mfa_token = error_data.get("mfa_token") mfa_requirements_data = error_data.get("mfa_requirements") @@ -1117,16 +1112,14 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, raise ApiError( error_code, - error_data.get("error_description", - "Failed to exchange refresh token") + error_data.get("error_description", "Failed to exchange refresh token"), ) token_response = response.json() # Add required fields if they are missing if "expires_in" in token_response and "expires_at" not in token_response: - token_response["expires_at"] = int( - time.time()) + token_response["expires_in"] + token_response["expires_at"] = int(time.time()) + token_response["expires_in"] return token_response @@ -1136,13 +1129,11 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, raise AccessTokenError( AccessTokenErrorCode.REFRESH_TOKEN_ERROR, "The access token has expired and there was an error while trying to refresh it.", - e + e, ) def _merge_scope_with_defaults( - self, - request_scope: Optional[str], - audience: Optional[str] + self, request_scope: Optional[str], audience: Optional[str] ) -> Optional[str]: """Helper: Merges requested scopes with default authorization params.""" audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY @@ -1163,10 +1154,7 @@ def _merge_scope_with_defaults( return " ".join(merged_scopes) if merged_scopes else None def _find_matching_token_set( - self, - token_sets: list[dict[str, Any]], - audience: Optional[str], - scope: Optional[str] + self, token_sets: list[dict[str, Any]], audience: Optional[str], scope: Optional[str] ) -> Optional[dict[str, Any]]: """Helper: Finds a token set matching the requested audience and scopes.""" audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY @@ -1192,9 +1180,7 @@ def _find_matching_token_set( # ============================================================================ async def login_backchannel( - self, - options: dict[str, Any], - store_options: Optional[dict[str, Any]] = None + self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Logs in using Client-Initiated Backchannel Authentication. @@ -1213,38 +1199,34 @@ async def login_backchannel( Returns: A dictionary containing the authorizationDetails (when RAR was used). """ - token_endpoint_response = await self.backchannel_authentication({ - "binding_message": options.get("binding_message"), - "login_hint": options.get("login_hint"), - "authorization_params": options.get("authorization_params"), - }, store_options=store_options) + token_endpoint_response = await self.backchannel_authentication( + { + "binding_message": options.get("binding_message"), + "login_hint": options.get("login_hint"), + "authorization_params": options.get("authorization_params"), + }, + store_options=store_options, + ) existing_state_data = await self._state_store.get(self._state_identifier, store_options) audience = self._default_authorization_params.get( - "audience", self.DEFAULT_AUDIENCE_STATE_KEY) - - state_data = State.update_state_data( - audience, - existing_state_data, - token_endpoint_response + "audience", self.DEFAULT_AUDIENCE_STATE_KEY ) + state_data = State.update_state_data(audience, existing_state_data, token_endpoint_response) + # Store domain for MCD session domain = await self._resolve_current_domain(store_options) state_data["domain"] = domain await self._state_store.set(self._state_identifier, state_data, store_options) - result = { - "authorization_details": token_endpoint_response.get("authorization_details") - } + result = {"authorization_details": token_endpoint_response.get("authorization_details")} return result async def backchannel_authentication( - self, - options: dict[str, Any], - store_options: Optional[dict[str, Any]] = None + self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Performs backchannel authentication with Auth0. @@ -1269,12 +1251,12 @@ async def backchannel_authentication( Raises: ApiError: If the backchannel authentication fails """ - backchannel_data = await self.initiate_backchannel_authentication(options, store_options=store_options) + backchannel_data = await self.initiate_backchannel_authentication( + options, store_options=store_options + ) auth_req_id = backchannel_data.get("auth_req_id") - expires_in = backchannel_data.get( - "expires_in", 120) # Default to 2 minutes - interval = backchannel_data.get( - "interval", 5) # Default to 5 seconds + expires_in = backchannel_data.get("expires_in", 120) # Default to 2 minutes + interval = backchannel_data.get("interval", 5) # Default to 5 seconds # Calculate when to stop polling end_time = time.time() + expires_in @@ -1283,7 +1265,9 @@ async def backchannel_authentication( while time.time() < end_time: # Make token request try: - token_response = await self.backchannel_authentication_grant(auth_req_id, store_options=store_options) + token_response = await self.backchannel_authentication_grant( + auth_req_id, store_options=store_options + ) return token_response except Exception as e: @@ -1299,17 +1283,14 @@ async def backchannel_authentication( raise ApiError( "backchannel_error", f"Backchannel authentication failed: {str(e) or 'Unknown error'}", - e + e, ) # If we get here, we've timed out - raise ApiError( - "timeout", "Backchannel authentication timed out") + raise ApiError("timeout", "Backchannel authentication timed out") async def initiate_backchannel_authentication( - self, - options: dict[str, Any], - store_options: Optional[dict[str, Any]] = None + self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Start backchannel authentication with Auth0. @@ -1339,18 +1320,13 @@ async def initiate_backchannel_authentication( https://auth0.com/docs/get-started/authentication-and-authorization-flow/client-initiated-backchannel-authentication-flow """ - sub = options.get('login_hint', {}).get("sub") + sub = options.get("login_hint", {}).get("sub") if not sub: - raise MissingRequiredArgumentError( - "login_hint.sub" - ) + raise MissingRequiredArgumentError("login_hint.sub") - authorization_params = options.get('authorization_params') + authorization_params = options.get("authorization_params") if authorization_params is not None and not isinstance(authorization_params, dict): - raise ApiError( - "invalid_argument", - "authorization_params must be a dict" - ) + raise ApiError("invalid_argument", "authorization_params must be a dict") if authorization_params: requested_expiry = authorization_params.get("requested_expiry") @@ -1358,7 +1334,7 @@ async def initiate_backchannel_authentication( if not isinstance(requested_expiry, int) or requested_expiry <= 0: raise ApiError( "invalid_argument", - "authorization_params.requested_expiry must be a positive integer" + "authorization_params.requested_expiry must be a positive integer", ) try: @@ -1367,24 +1343,18 @@ async def initiate_backchannel_authentication( metadata = await self._get_oidc_metadata_cached(domain) # Get the issuer from metadata - issuer = metadata.get( - "issuer") or f"https://{domain}/" + issuer = metadata.get("issuer") or f"https://{domain}/" # Get backchannel authentication endpoint - backchannel_endpoint = metadata.get( - "backchannel_authentication_endpoint") + backchannel_endpoint = metadata.get("backchannel_authentication_endpoint") if not backchannel_endpoint: raise ApiError( "configuration_error", - "Backchannel authentication is not supported by the authorization server" + "Backchannel authentication is not supported by the authorization server", ) # Prepare login hint in the required format - login_hint = json.dumps({ - "format": "iss_sub", - "iss": issuer, - "sub": sub - }) + login_hint = json.dumps({"format": "iss_sub", "iss": issuer, "sub": sub}) # The Request Parameters params = { @@ -1394,8 +1364,8 @@ async def initiate_backchannel_authentication( } # Add binding message if provided - if options.get('binding_message'): - params["binding_message"] = options.get('binding_message') + if options.get("binding_message"): + params["binding_message"] = options.get("binding_message") # Add any additional authorization parameters if self._default_authorization_params: @@ -1407,9 +1377,7 @@ async def initiate_backchannel_authentication( # Make the backchannel authentication request async with self._get_http_client() as client: backchannel_response = await client.post( - backchannel_endpoint, - data=params, - auth=(self._client_id, self._client_secret) + backchannel_endpoint, data=params, auth=(self._client_id, self._client_secret) ) if backchannel_response.status_code != 200: @@ -1417,7 +1385,8 @@ async def initiate_backchannel_authentication( raise ApiError( error_data.get("error", "backchannel_error"), error_data.get( - "error_description", "Backchannel authentication request failed") + "error_description", "Backchannel authentication request failed" + ), ) backchannel_data = backchannel_response.json() @@ -1426,7 +1395,7 @@ async def initiate_backchannel_authentication( if not auth_req_id: raise ApiError( "invalid_response", - "Missing auth_req_id in backchannel authentication response" + "Missing auth_req_id in backchannel authentication response", ) return backchannel_data @@ -1437,13 +1406,11 @@ async def initiate_backchannel_authentication( raise ApiError( "backchannel_error", f"Backchannel authentication failed: {str(e) or 'Unknown error'}", - e + e, ) async def backchannel_authentication_grant( - self, - auth_req_id: str, - store_options: Optional[dict[str, Any]] = None + self, auth_req_id: str, store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Retrieves a token by exchanging an auth_req_id. @@ -1468,23 +1435,20 @@ async def backchannel_authentication_grant( token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", - "Token endpoint missing in OIDC metadata") + raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") # Prepare the token request parameters token_params = { "grant_type": "urn:openid:params:grant-type:ciba", "auth_req_id": auth_req_id, "client_id": self._client_id, - "client_secret": self._client_secret + "client_secret": self._client_secret, } # Exchange the auth_req_id for an access token async with self._get_http_client() as client: response = await client.post( - token_endpoint, - data=token_params, - auth=(self._client_id, self._client_secret) + token_endpoint, data=token_params, auth=(self._client_id, self._client_secret) ) if response.status_code != 200: @@ -1493,23 +1457,18 @@ async def backchannel_authentication_grant( interval = int(retry_after) if retry_after is not None else None raise PollingApiError( error_data.get("error", "auth_req_id_error"), - error_data.get("error_description", - "Failed to exchange auth_req_id"), - interval + error_data.get("error_description", "Failed to exchange auth_req_id"), + interval, ) try: token_response = response.json() except json.JSONDecodeError: - raise ApiError( - "invalid_response", - "Failed to parse token response as JSON" - ) + raise ApiError("invalid_response", "Failed to parse token response as JSON") # Add required fields if they are missing if "expires_in" in token_response and "expires_at" not in token_response: - token_response["expires_at"] = int( - time.time()) + token_response["expires_in"] + token_response["expires_at"] = int(time.time()) + token_response["expires_in"] return token_response @@ -1519,7 +1478,7 @@ async def backchannel_authentication_grant( raise AccessTokenError( AccessTokenErrorCode.AUTH_REQ_ID_ERROR, "There was an error while trying to exchange the auth_req_id for an access token.", - e + e, ) # ============================================================================ @@ -1528,11 +1487,7 @@ async def backchannel_authentication_grant( # to a user's Auth0 profile. # ============================================================================ - async def start_link_user( - self, - options, - store_options: Optional[dict[str, Any]] = None - ): + async def start_link_user(self, options, store_options: Optional[dict[str, Any]] = None): """ Starts the user linking process, and returns a URL to redirect the user-agent to. @@ -1559,13 +1514,9 @@ async def start_link_user( state_data = state_data.dict() session_domain = self._get_session_domain(state_data) if not session_domain: - raise StartLinkUserError( - "Session domain does not match the current domain." - ) + raise StartLinkUserError("Session domain does not match the current domain.") if self._normalize_url(session_domain) != self._normalize_url(origin_domain): - raise StartLinkUserError( - "Session domain does not match the current domain." - ) + raise StartLinkUserError("Session domain does not match the current domain.") # Generate PKCE and state for security code_verifier = PKCE.generate_code_verifier() @@ -1579,7 +1530,7 @@ async def start_link_user( code_verifier=code_verifier, state=state, authorization_params=options.get("authorization_params"), - domain=origin_domain + domain=origin_domain, ) # Store transaction data @@ -1590,17 +1541,13 @@ async def start_link_user( ) await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", - transaction_data, - options=store_options + f"{self._transaction_identifier}:{state}", transaction_data, options=store_options ) return link_user_url async def complete_link_user( - self, - url: str, - store_options: Optional[dict[str, Any]] = None + self, url: str, store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Completes the user linking process. @@ -1617,15 +1564,9 @@ async def complete_link_user( result = await self.complete_interactive_login(url, store_options) # Return just the app state as specified - return { - "app_state": result.get("app_state") - } + return {"app_state": result.get("app_state")} - async def start_unlink_user( - self, - options, - store_options: Optional[dict[str, Any]] = None - ): + async def start_unlink_user(self, options, store_options: Optional[dict[str, Any]] = None): """ Starts the user unlinking process, and returns a URL to redirect the user-agent to. @@ -1652,13 +1593,9 @@ async def start_unlink_user( state_data = state_data.dict() session_domain = self._get_session_domain(state_data) if not session_domain: - raise StartLinkUserError( - "Session domain does not match the current domain." - ) + raise StartLinkUserError("Session domain does not match the current domain.") if self._normalize_url(session_domain) != self._normalize_url(origin_domain): - raise StartLinkUserError( - "Session domain does not match the current domain." - ) + raise StartLinkUserError("Session domain does not match the current domain.") # Generate PKCE and state for security code_verifier = PKCE.generate_code_verifier() @@ -1671,7 +1608,7 @@ async def start_unlink_user( code_verifier=code_verifier, state=state, authorization_params=options.get("authorization_params"), - domain=origin_domain + domain=origin_domain, ) # Store transaction data @@ -1682,17 +1619,13 @@ async def start_unlink_user( ) await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", - transaction_data, - options=store_options + f"{self._transaction_identifier}:{state}", transaction_data, options=store_options ) return link_user_url async def complete_unlink_user( - self, - url: str, - store_options: Optional[dict[str, Any]] = None + self, url: str, store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Completes the user unlinking process. @@ -1709,9 +1642,7 @@ async def complete_unlink_user( result = await self.complete_interactive_login(url, store_options) # Return just the app state as specified - return { - "app_state": result.get("app_state") - } + return {"app_state": result.get("app_state")} async def _build_link_user_url( self, @@ -1721,7 +1652,7 @@ async def _build_link_user_url( state: str, connection_scope: Optional[str] = None, authorization_params: Optional[dict[str, Any]] = None, - domain: Optional[str] = None + domain: Optional[str] = None, ) -> str: """Build a URL for linking user accounts""" # Generate code challenge from verifier @@ -1732,8 +1663,9 @@ async def _build_link_user_url( metadata = await self._get_oidc_metadata_cached(resolved_domain) # Get authorization endpoint - auth_endpoint = metadata.get("authorization_endpoint", - f"https://{resolved_domain}/authorize") + auth_endpoint = metadata.get( + "authorization_endpoint", f"https://{resolved_domain}/authorize" + ) # Build params params = { @@ -1746,7 +1678,7 @@ async def _build_link_user_url( "response_type": "code", "id_token_hint": id_token, "scope": "openid link_account", - "audience": "my-account" + "audience": "my-account", } # Add connection scope if provided @@ -1765,7 +1697,7 @@ async def _build_unlink_user_url( code_verifier: str, state: str, authorization_params: Optional[dict[str, Any]] = None, - domain: Optional[str] = None + domain: Optional[str] = None, ) -> str: """Build a URL for unlinking user accounts""" # Generate code challenge from verifier @@ -1776,8 +1708,9 @@ async def _build_unlink_user_url( metadata = await self._get_oidc_metadata_cached(resolved_domain) # Get authorization endpoint - auth_endpoint = metadata.get("authorization_endpoint", - f"https://{resolved_domain}/authorize") + auth_endpoint = metadata.get( + "authorization_endpoint", f"https://{resolved_domain}/authorize" + ) # Build params params = { @@ -1789,7 +1722,7 @@ async def _build_unlink_user_url( "response_type": "code", "id_token_hint": id_token, "scope": "openid unlink_account", - "audience": "my-account" + "audience": "my-account", } # Add any additional parameters if authorization_params: @@ -1804,9 +1737,7 @@ async def _build_unlink_user_url( # ============================================================================ async def get_access_token_for_connection( - self, - options: dict[str, Any], - store_options: Optional[dict[str, Any]] = None + self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None ) -> str: """ Retrieves an access token for a connection. @@ -1840,13 +1771,13 @@ async def get_access_token_for_connection( if not session_domain: raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.MISSING_SESSION_DOMAIN, - "Session domain does not match the current domain." + "Session domain does not match the current domain.", ) current_domain = await self._resolve_current_domain(store_options) if self._normalize_url(session_domain) != self._normalize_url(current_domain): raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.DOMAIN_MISMATCH, - "Session domain does not match the current domain." + "Session domain does not match the current domain.", ) # Find existing connection token @@ -1865,21 +1796,24 @@ async def get_access_token_for_connection( if not state_data_dict or not state_data_dict.get("refresh_token"): raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.MISSING_REFRESH_TOKEN, - "A refresh token was not found but is required to be able to retrieve an access token for a connection." + "A refresh token was not found but is required to be able to retrieve an access token for a connection.", ) # Get new token for connection # Use session's domain for token exchange session_domain = state_data_dict.get("domain") or self._domain - token_endpoint_response = await self.get_token_for_connection({ - "connection": options.get("connection"), - "login_hint": options.get("login_hint"), - "refresh_token": state_data_dict["refresh_token"], - "domain": session_domain - }) + token_endpoint_response = await self.get_token_for_connection( + { + "connection": options.get("connection"), + "login_hint": options.get("login_hint"), + "refresh_token": state_data_dict["refresh_token"], + "domain": session_domain, + } + ) # Update state data with new token updated_state_data = State.update_state_data_for_connection_token_set( - options, state_data_dict, token_endpoint_response) + options, state_data_dict, token_endpoint_response + ) # Store updated state await self._state_store.set(self._state_identifier, updated_state_data, store_options) @@ -1903,8 +1837,12 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A """ # Constants SUBJECT_TYPE_REFRESH_TOKEN = "urn:ietf:params:oauth:token-type:refresh_token" - REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "http://auth0.com/oauth/token-type/federated-connection-access-token" - GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" + REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = ( + "http://auth0.com/oauth/token-type/federated-connection-access-token" + ) + GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = ( + "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" + ) try: # Use session domain if provided, otherwise fallback to static domain domain = options.get("domain") or self._domain @@ -1914,8 +1852,7 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", - "Token endpoint missing in OIDC metadata") + raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") # Prepare parameters params = { @@ -1924,7 +1861,7 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A "subject_token": options["refresh_token"], "requested_token_type": REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN, "grant_type": GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN, - "client_id": self._client_id + "client_id": self._client_id, } # Add login_hint if provided @@ -1934,38 +1871,41 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A # Make the request async with self._get_http_client() as client: response = await client.post( - token_endpoint, - data=params, - auth=(self._client_id, self._client_secret) + token_endpoint, data=params, auth=(self._client_id, self._client_secret) ) if response.status_code != 200: - error_data = response.json() if response.headers.get( - "content-type") == "application/json" else {} + error_data = ( + response.json() + if response.headers.get("content-type") == "application/json" + else {} + ) raise ApiError( error_data.get("error", "connection_token_error"), error_data.get( - "error_description", f"Failed to get token for connection: {response.status_code}") + "error_description", + f"Failed to get token for connection: {response.status_code}", + ), ) token_endpoint_response = response.json() return { "access_token": token_endpoint_response.get("access_token"), - "expires_at": int(time.time()) + int(token_endpoint_response.get("expires_in", 3600)), - "scope": token_endpoint_response.get("scope", "") + "expires_at": int(time.time()) + + int(token_endpoint_response.get("expires_in", 3600)), + "scope": token_endpoint_response.get("scope", ""), } except Exception as e: if isinstance(e, ApiError): raise AccessTokenForConnectionError( - AccessTokenForConnectionErrorCode.API_ERROR, - str(e) + AccessTokenForConnectionErrorCode.API_ERROR, str(e) ) raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.FETCH_ERROR, "There was an error while trying to retrieve an access token for a connection.", - e + e, ) # ============================================================================ @@ -1975,9 +1915,7 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A # ============================================================================ async def start_connect_account( - self, - options: ConnectAccountOptions, - store_options: dict = None + self, options: ConnectAccountOptions, store_options: dict = None ) -> str: """ Initiates the connect account flow for linking a third-party account to the user's profile. @@ -2002,26 +1940,25 @@ async def start_connect_account( code_verifier = PKCE.generate_code_verifier() code_challenge = PKCE.generate_code_challenge(code_verifier) - state= PKCE.generate_random_string(32) + state = PKCE.generate_random_string(32) connect_request = ConnectAccountRequest( connection=options.connection, scopes=options.scopes, - redirect_uri = redirect_uri, + redirect_uri=redirect_uri, code_challenge=code_challenge, code_challenge_method="S256", state=state, - authorization_params=options.authorization_params + authorization_params=options.authorization_params, ) access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="create:me:connected_accounts", - store_options=store_options + store_options=store_options, ) connect_response = await self._my_account_client.connect_account( - access_token=access_token, - request=connect_request + access_token=access_token, request=connect_request ) # Build the transaction data to store @@ -2029,24 +1966,29 @@ async def start_connect_account( code_verifier=code_verifier, app_state=options.app_state, auth_session=connect_response.auth_session, - redirect_uri=redirect_uri + redirect_uri=redirect_uri, ) # Store the transaction data await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", - transaction_data, - options=store_options + f"{self._transaction_identifier}:{state}", transaction_data, options=store_options ) parsedUrl = urlparse(connect_response.connect_uri) query = urlencode({"ticket": connect_response.connect_params.ticket}) - return urlunparse((parsedUrl.scheme, parsedUrl.netloc, parsedUrl.path, parsedUrl.params, query, parsedUrl.fragment)) + return urlunparse( + ( + parsedUrl.scheme, + parsedUrl.netloc, + parsedUrl.path, + parsedUrl.params, + query, + parsedUrl.fragment, + ) + ) async def complete_connect_account( - self, - url: str, - store_options: dict = None + self, url: str, store_options: dict = None ) -> CompleteConnectAccountResponse: """ Handles the redirect callback to complete the connect account flow for linking a third-party @@ -2078,7 +2020,9 @@ async def complete_connect_account( # Retrieve the transaction data using the state transaction_identifier = f"{self._transaction_identifier}:{state}" - transaction_data = await self._transaction_store.get(transaction_identifier, options=store_options) + transaction_data = await self._transaction_store.get( + transaction_identifier, options=store_options + ) if not transaction_data: raise MissingTransactionError() @@ -2086,18 +2030,19 @@ async def complete_connect_account( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="create:me:connected_accounts", - store_options=store_options + store_options=store_options, ) request = CompleteConnectAccountRequest( auth_session=transaction_data.auth_session, connect_code=connect_code, redirect_uri=transaction_data.redirect_uri, - code_verifier=transaction_data.code_verifier + code_verifier=transaction_data.code_verifier, ) try: response = await self._my_account_client.complete_connect_account( - access_token=access_token, request=request) + access_token=access_token, request=request + ) if transaction_data.app_state is not None: response.app_state = transaction_data.app_state finally: @@ -2111,7 +2056,7 @@ async def list_connected_accounts( connection: Optional[str] = None, from_param: Optional[str] = None, take: Optional[int] = None, - store_options: dict = None + store_options: dict = None, ) -> ListConnectedAccountsResponse: """ Retrieves a list of connected accounts for the authenticated user. @@ -2135,15 +2080,14 @@ async def list_connected_accounts( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="read:me:connected_accounts", - store_options=store_options + store_options=store_options, ) return await self._my_account_client.list_connected_accounts( - access_token=access_token, connection=connection, from_param=from_param, take=take) + access_token=access_token, connection=connection, from_param=from_param, take=take + ) async def delete_connected_account( - self, - connected_account_id: str, - store_options: dict = None + self, connected_account_id: str, store_options: dict = None ) -> None: """ Deletes a connected account. @@ -2162,16 +2106,17 @@ async def delete_connected_account( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="delete:me:connected_accounts", - store_options=store_options + store_options=store_options, ) await self._my_account_client.delete_connected_account( - access_token=access_token, connected_account_id=connected_account_id) + access_token=access_token, connected_account_id=connected_account_id + ) async def list_connected_account_connections( self, from_param: Optional[str] = None, take: Optional[int] = None, - store_options: dict = None + store_options: dict = None, ) -> ListConnectedAccountConnectionsResponse: """ Retrieves a list of available connections that can be used connected accounts for the authenticated user. @@ -2194,10 +2139,11 @@ async def list_connected_account_connections( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="read:me:connected_accounts", - store_options=store_options + store_options=store_options, ) return await self._my_account_client.list_connected_account_connections( - access_token=access_token, from_param=from_param, take=take) + access_token=access_token, from_param=from_param, take=take + ) # ============================================================================ # CUSTOM TOKEN EXCHANGE (RFC 8693) @@ -2205,9 +2151,7 @@ async def list_connected_account_connections( # ============================================================================ async def custom_token_exchange( - self, - options: CustomTokenExchangeOptions, - store_options: Optional[dict[str, Any]] = None + self, options: CustomTokenExchangeOptions, store_options: Optional[dict[str, Any]] = None ) -> TokenExchangeResponse: """ Exchanges a custom token for Auth0 tokens using RFC 8693. @@ -2280,7 +2224,12 @@ async def custom_token_exchange( # Merge additional authorization params if options.authorization_params: # Prevent override of critical parameters - forbidden_params = {"grant_type", "client_id", "subject_token", "subject_token_type"} + forbidden_params = { + "grant_type", + "client_id", + "subject_token", + "subject_token_type", + } for key, value in options.authorization_params.items(): if key not in forbidden_params: params[key] = value @@ -2288,17 +2237,20 @@ async def custom_token_exchange( # Make the token exchange request async with self._get_http_client() as client: response = await client.post( - token_endpoint, - data=params, - auth=(self._client_id, self._client_secret) + token_endpoint, data=params, auth=(self._client_id, self._client_secret) ) if response.status_code != 200: - error_data = response.json() if response.headers.get( - "content-type", "").startswith("application/json") else {} + error_data = ( + response.json() + if response.headers.get("content-type", "").startswith("application/json") + else {} + ) raise CustomTokenExchangeError( error_data.get("error", CustomTokenExchangeErrorCode.TOKEN_EXCHANGE_FAILED), - error_data.get("error_description", f"Token exchange failed: {response.status_code}") + error_data.get( + "error_description", f"Token exchange failed: {response.status_code}" + ), ) try: @@ -2306,7 +2258,7 @@ async def custom_token_exchange( except json.JSONDecodeError: raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.INVALID_RESPONSE, - "Failed to parse token response as JSON" + "Failed to parse token response as JSON", ) # Validate and return response @@ -2315,7 +2267,7 @@ async def custom_token_exchange( except ValidationError as e: raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.INVALID_TOKEN_FORMAT, - f"Token validation failed: {str(e)}" + f"Token validation failed: {str(e)}", ) except Exception as e: if isinstance(e, (CustomTokenExchangeError, ApiError)): @@ -2323,13 +2275,13 @@ async def custom_token_exchange( raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.TOKEN_EXCHANGE_FAILED, f"Token exchange failed: {str(e)}", - e + e, ) async def login_with_custom_token_exchange( self, options: LoginWithCustomTokenExchangeOptions, - store_options: Optional[dict[str, Any]] = None + store_options: Optional[dict[str, Any]] = None, ) -> LoginWithCustomTokenExchangeResult: """ Performs token exchange and establishes a user session. @@ -2374,10 +2326,12 @@ async def login_with_custom_token_exchange( actor_token=options.actor_token, actor_token_type=options.actor_token_type, organization=options.organization, - authorization_params=options.authorization_params + authorization_params=options.authorization_params, ) - token_response = await self.custom_token_exchange(exchange_options, store_options=store_options) + token_response = await self.custom_token_exchange( + exchange_options, store_options=store_options + ) # Resolve domain and fetch metadata for verification domain = await self._resolve_current_domain(store_options) @@ -2409,28 +2363,18 @@ async def login_with_custom_token_exchange( raise ApiError("jwks_key_not_found", str(e)) except jwt.InvalidSignatureError as e: raise ApiError( - "invalid_signature", - f"ID token signature verification failed: {str(e)}", - e + "invalid_signature", f"ID token signature verification failed: {str(e)}", e ) except jwt.InvalidAudienceError as e: raise ApiError( "invalid_audience", f"ID token audience mismatch. Expected: {self._client_id}: {str(e)}", - e + e, ) except jwt.ExpiredSignatureError as e: - raise ApiError( - "token_expired", - f"ID token has expired: {str(e)}", - e - ) + raise ApiError("token_expired", f"ID token has expired: {str(e)}", e) except jwt.InvalidTokenError as e: - raise ApiError( - "invalid_token", - f"ID token verification failed: {str(e)}", - e - ) + raise ApiError("invalid_token", f"ID token verification failed: {str(e)}", e) # Determine audience for token set audience = options.audience or self.DEFAULT_AUDIENCE_STATE_KEY @@ -2440,7 +2384,7 @@ async def login_with_custom_token_exchange( audience=audience, access_token=token_response.access_token, scope=token_response.scope or options.scope or "", - expires_at=int(time.time()) + token_response.expires_in + expires_at=int(time.time()) + token_response.expires_in, ) # Construct state data @@ -2450,19 +2394,14 @@ async def login_with_custom_token_exchange( refresh_token=token_response.refresh_token, token_sets=[token_set], domain=domain, - internal={ - "sid": sid, - "created_at": int(time.time()) - } + internal={"sid": sid, "created_at": int(time.time())}, ) # Store session await self._state_store.set(self._state_identifier, state_data, options=store_options) # Build result - result = LoginWithCustomTokenExchangeResult( - state_data=state_data.dict() - ) + result = LoginWithCustomTokenExchangeResult(state_data=state_data.dict()) return result @@ -2472,9 +2411,291 @@ async def login_with_custom_token_exchange( raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.TOKEN_EXCHANGE_FAILED, f"Login with custom token exchange failed: {str(e)}", - e + e, ) + # ============================================================================ + # PASSKEY AUTHENTICATION (Category 1) + # ============================================================================ + + GRANT_TYPE_PASSKEY = "urn:okta:params:oauth:grant-type:webauthn" + + async def passkey_signup_challenge( + self, + name: Optional[str] = None, + email: Optional[str] = None, + username: Optional[str] = None, + phone_number: Optional[str] = None, + given_name: Optional[str] = None, + family_name: Optional[str] = None, + nickname: Optional[str] = None, + picture: Optional[str] = None, + user_metadata: Optional[dict[str, Any]] = None, + connection: Optional[str] = None, + organization: Optional[str] = None, + store_options: Optional[dict[str, Any]] = None, + ) -> PasskeySignupChallengeResponse: + """ + Step 1 of 2: Initiate a passkey signup challenge (POST /passkey/register). + + Pass the returned authn_params_public_key to navigator.credentials.create(), + then call signin_with_passkey() with the auth_session and credential result. + + Args: + name: User's full name. + email: User's email address. + username: Username for the new account. + phone_number: User's phone number. + given_name: User's given (first) name. + family_name: User's family (last) name. + nickname: User's nickname. + picture: URL to the user's profile picture. + user_metadata: Arbitrary user metadata dict. + connection: Auth0 database connection name (realm). + organization: Auth0 organization ID or name. + store_options: Optional options for domain resolution. + + Returns: + PasskeySignupChallengeResponse with auth_session and authn_params_public_key. + + Raises: + ApiError: If the challenge request fails. + """ + try: + domain = await self._resolve_current_domain(store_options) + + user_profile: dict[str, Any] = {} + if email is not None: + user_profile["email"] = email + if name is not None: + user_profile["name"] = name + if username is not None: + user_profile["username"] = username + if phone_number is not None: + user_profile["phone_number"] = phone_number + if given_name is not None: + user_profile["given_name"] = given_name + if family_name is not None: + user_profile["family_name"] = family_name + if nickname is not None: + user_profile["nickname"] = nickname + if picture is not None: + user_profile["picture"] = picture + if user_metadata is not None: + user_profile["user_metadata"] = user_metadata + + body: dict[str, Any] = {"client_id": self._client_id} + if self._client_secret: + body["client_secret"] = self._client_secret + if user_profile: + body["user_profile"] = user_profile + if connection: + body["realm"] = connection + if organization: + body["organization"] = organization + + url = f"https://{domain}/passkey/register" + + async with self._get_http_client() as client: + response = await client.post(url, json=body) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "passkey_challenge_error", + f"Passkey signup challenge failed with status {response.status_code}", + ) + raise ApiError( + error_data.get("error", "passkey_challenge_error"), + error_data.get("error_description", "Passkey signup challenge failed"), + ) + + try: + data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "invalid_response", + "Failed to parse passkey signup challenge response as JSON", + ) + + return PasskeySignupChallengeResponse.model_validate(data) + + except Exception as e: + if isinstance(e, (ApiError, MissingRequiredArgumentError, ValidationError)): + raise + raise ApiError("passkey_challenge_error", "Passkey signup challenge failed", e) + + async def passkey_login_challenge( + self, + username: Optional[str] = None, + connection: Optional[str] = None, + organization: Optional[str] = None, + store_options: Optional[dict[str, Any]] = None, + ) -> PasskeyLoginChallengeResponse: + """ + Step 1 of 2: Initiate a passkey login challenge (POST /passkey/challenge). + + Pass the returned authn_params_public_key to navigator.credentials.get(), + then call signin_with_passkey() with the auth_session and credential result. + + Args: + username: Optional username hint for conditional UI. + connection: Auth0 database connection name (realm). + organization: Auth0 organization ID or name. + store_options: Optional options for domain resolution. + + Returns: + PasskeyLoginChallengeResponse with auth_session and authn_params_public_key. + + Raises: + ApiError: If the challenge request fails. + """ + try: + domain = await self._resolve_current_domain(store_options) + + body: dict[str, Any] = {"client_id": self._client_id} + if self._client_secret: + body["client_secret"] = self._client_secret + if username: + body["username"] = username + if connection: + body["realm"] = connection + if organization: + body["organization"] = organization + + url = f"https://{domain}/passkey/challenge" + + async with self._get_http_client() as client: + response = await client.post(url, json=body) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "passkey_challenge_error", + f"Passkey login challenge failed with status {response.status_code}", + ) + raise ApiError( + error_data.get("error", "passkey_challenge_error"), + error_data.get("error_description", "Passkey login challenge failed"), + ) + + try: + data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "invalid_response", + "Failed to parse passkey login challenge response as JSON", + ) + + return PasskeyLoginChallengeResponse.model_validate(data) + + except Exception as e: + if isinstance(e, (ApiError, MissingRequiredArgumentError, ValidationError)): + raise + raise ApiError("passkey_challenge_error", "Passkey login challenge failed", e) + + async def signin_with_passkey( + self, + auth_session: str, + authn_response: PasskeyAuthResponse, + store_options: Optional[dict[str, Any]] = None, + connection: Optional[str] = None, + organization: Optional[str] = None, + scope: Optional[str] = None, + audience: Optional[str] = None, + ) -> PasskeyTokenResponse: + """ + Completes passkey authentication by exchanging the WebAuthn assertion + for tokens (POST /oauth/token with webauthn grant). + + This is step 2 of 2: call passkey_signup_challenge or passkey_login_challenge + first to obtain auth_session and the WebAuthn challenge options. + + Uses Content-Type: application/json (required for nested authn_response). + + Args: + auth_session: Session credential from passkey_signup_challenge or passkey_login_challenge. + authn_response: Serialized WebAuthn credential from navigator.credentials.create/get. + store_options: Optional options for domain resolution and state store. + connection: Auth0 database connection name (realm). + organization: Auth0 organization ID or name. + scope: OAuth2 scope string. + audience: Target API audience. + + Returns: + PasskeyTokenResponse containing access_token, id_token, expires_in, etc. + + Raises: + MissingRequiredArgumentError: If auth_session or authn_response is missing. + ApiError: If token exchange fails. + """ + if not auth_session: + raise MissingRequiredArgumentError("auth_session") + if authn_response is None: + raise MissingRequiredArgumentError("authn_response") + + try: + domain = await self._resolve_current_domain(store_options) + metadata = await self._get_oidc_metadata_cached(domain) + + token_endpoint = metadata.get("token_endpoint") + if not token_endpoint: + raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") + + body: dict[str, Any] = { + "grant_type": self.GRANT_TYPE_PASSKEY, + "client_id": self._client_id, + "auth_session": auth_session, + "authn_response": authn_response.model_dump(by_alias=True, exclude_none=True), + } + if self._client_secret: + body["client_secret"] = self._client_secret + if connection: + body["realm"] = connection + if organization: + body["organization"] = organization + if scope: + body["scope"] = scope + if audience: + body["audience"] = audience + + async with self._get_http_client() as client: + response = await client.post(token_endpoint, json=body) + + if response.status_code != 200: + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "passkey_token_error", + f"Passkey token exchange failed with status {response.status_code}", + ) + raise ApiError( + error_data.get("error", "passkey_token_error"), + error_data.get("error_description", "Passkey token exchange failed"), + ) + + try: + token_data = response.json() + except (json.JSONDecodeError, ValueError): + raise ApiError( + "invalid_response", "Failed to parse passkey token response as JSON" + ) + + if "expires_in" in token_data and "expires_at" not in token_data: + token_data["expires_at"] = int(time.time()) + token_data["expires_in"] + + return PasskeyTokenResponse.model_validate(token_data) + + except Exception as e: + if isinstance(e, (ApiError, MissingRequiredArgumentError, ValidationError)): + raise + raise ApiError("passkey_token_error", "Passkey sign-in failed", e) + # ============================================================================ # MFA (Multi-Factor Authentication) # ============================================================================ diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 055103a..d306efa 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -5,7 +5,7 @@ from typing import Any, Literal, Optional, Union -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator class UserClaims(BaseModel): @@ -13,6 +13,7 @@ class UserClaims(BaseModel): User profile information as returned by Auth0. Contains standard OIDC claims about the authenticated user. """ + sub: str name: Optional[str] = None nickname: Optional[str] = None @@ -32,6 +33,7 @@ class TokenSet(BaseModel): Represents a set of tokens issued by Auth0. Contains the access token and related metadata. """ + audience: str access_token: str scope: Optional[str] = None @@ -43,6 +45,7 @@ class ConnectionTokenSet(TokenSet): Token set specific to a connection. Extends TokenSet with connection-specific information. """ + connection: str login_hint: str @@ -52,6 +55,7 @@ class InternalStateData(BaseModel): Internal data used for managing state. Not meant to be accessed directly by SDK users. """ + sid: str created_at: int @@ -61,6 +65,7 @@ class SessionData(BaseModel): Represents a user session with Auth0. Contains user information and tokens. """ + user: Optional[UserClaims] = None id_token: Optional[str] = None refresh_token: Optional[str] = None @@ -77,6 +82,7 @@ class StateData(SessionData): Complete state data stored in the state store. Extends SessionData with internal management information. """ + internal: InternalStateData @@ -85,6 +91,7 @@ class TransactionData(BaseModel): Represents data for an in-progress authentication transaction. Used during the authorization code flow to correlate requests. """ + audience: Optional[str] = None code_verifier: str app_state: Optional[Any] = None @@ -101,6 +108,7 @@ class LogoutTokenClaims(BaseModel): Claims expected in a logout token. Used for backchannel logout processing. """ + sub: str sid: str iss: Optional[str] = None @@ -111,6 +119,7 @@ class EncryptedStoreOptions(BaseModel): Options for encrypted stores. Contains the secret used for encryption. """ + secret: str @@ -119,6 +128,7 @@ class ServerClientOptionsBase(BaseModel): Base options for configuring the Auth0 server client. Contains core settings required for all clients. """ + domain: str client_id: str client_secret: str @@ -135,6 +145,7 @@ class ServerClientOptionsWithSecret(ServerClientOptionsBase): Client options using a secret for encryption. Extends base options with secret and duration settings. """ + secret: str state_absolute_duration: Optional[int] = 259200 # 3 days in seconds @@ -144,6 +155,7 @@ class StartInteractiveLoginOptions(BaseModel): Options for starting the interactive login process. Configures how the authorization request is constructed. """ + pushed_authorization_requests: Optional[bool] = False app_state: Optional[Any] = None authorization_params: Optional[dict[str, Any]] = None @@ -154,6 +166,7 @@ class LogoutOptions(BaseModel): Options for logout operations. Configures how the logout request is constructed. """ + return_to: Optional[str] = None @@ -162,6 +175,7 @@ class AuthorizationParameters(BaseModel): Parameters used in authorization requests. Based on standard OAuth2/OIDC parameters. """ + scope: Optional[str] = None audience: Optional[str] = None redirect_uri: Optional[str] = None @@ -169,11 +183,13 @@ class AuthorizationParameters(BaseModel): class Config: extra = "allow" # Allow additional OAuth parameters + class AuthorizationDetails(BaseModel): """ Authorization details returned from Auth0. Used for Resource Access Rights (RAR). """ + type: str actions: Optional[list[str]] = None locations: Optional[list[str]] = None @@ -188,6 +204,7 @@ class LoginBackchannelOptions(BaseModel): """ Options for Client-Initiated Backchannel Authentication. """ + binding_message: str login_hint: dict[str, str] # Should contain a 'sub' field authorization_params: Optional[dict[str, Any]] = None @@ -200,6 +217,7 @@ class LoginBackchannelResult(BaseModel): """ Result from Client-Initiated Backchannel Authentication. """ + authorization_details: Optional[list[AuthorizationDetails]] = None @@ -207,19 +225,23 @@ class AccessTokenForConnectionOptions(BaseModel): """ Options for retrieving an access token for a specific connection. """ + connection: str login_hint: Optional[str] = None + class StartLinkUserOptions(BaseModel): connection: str connection_scope: Optional[str] = None authorization_params: Optional[dict[str, Any]] = None app_state: Optional[Any] = None + # ============================================================================= # Multiple Custom Domain # ============================================================================= + class DomainResolverContext(BaseModel): """ Context passed to domain resolver function for MCD support. @@ -236,13 +258,16 @@ async def domain_resolver(context: DomainResolverContext) -> str: host = context.request_headers.get('host', '').split(':')[0] return DOMAIN_MAP.get(host, DEFAULT_DOMAIN) """ + request_url: Optional[str] = None request_headers: Optional[dict[str, str]] = None + # ============================================================================= # Custom Token Exchange Types # ============================================================================= + class CustomTokenExchangeOptions(BaseModel): """ Options for custom token exchange (RFC 8693). @@ -257,6 +282,7 @@ class CustomTokenExchangeOptions(BaseModel): organization: Organization identifier for the token exchange (optional) authorization_params: Additional OAuth parameters (optional) """ + subject_token: str = Field(..., min_length=1) subject_token_type: str = Field(..., min_length=1) audience: Optional[str] = None @@ -266,7 +292,7 @@ class CustomTokenExchangeOptions(BaseModel): organization: Optional[str] = None authorization_params: Optional[dict[str, Any]] = None - @field_validator('subject_token', 'actor_token') + @field_validator("subject_token", "actor_token") @classmethod def validate_token_format(cls, v: Optional[str]) -> Optional[str]: """Validate token doesn't have Bearer prefix and isn't whitespace-only.""" @@ -277,8 +303,8 @@ def validate_token_format(cls, v: Optional[str]) -> Optional[str]: raise ValueError("Token should not include 'Bearer ' prefix") return v - @model_validator(mode='after') - def validate_actor_token_type(self) -> 'CustomTokenExchangeOptions': + @model_validator(mode="after") + def validate_actor_token_type(self) -> "CustomTokenExchangeOptions": """Ensure actor_token_type is provided if actor_token is present.""" if self.actor_token and not self.actor_token_type: raise ValueError("actor_token_type is required when actor_token is provided") @@ -298,6 +324,7 @@ class TokenExchangeResponse(BaseModel): id_token: OpenID Connect ID token (optional) refresh_token: Refresh token (optional) """ + access_token: str token_type: str = "Bearer" expires_in: int @@ -313,6 +340,7 @@ class LoginWithCustomTokenExchangeOptions(BaseModel): Combines token exchange parameters with session management. """ + subject_token: str = Field(..., min_length=1) subject_token_type: str = Field(..., min_length=1) audience: Optional[str] = None @@ -322,7 +350,7 @@ class LoginWithCustomTokenExchangeOptions(BaseModel): organization: Optional[str] = None authorization_params: Optional[dict[str, Any]] = None - @field_validator('subject_token', 'actor_token') + @field_validator("subject_token", "actor_token") @classmethod def validate_token_format(cls, v: Optional[str]) -> Optional[str]: """Validate token doesn't have Bearer prefix and isn't whitespace-only.""" @@ -333,8 +361,8 @@ def validate_token_format(cls, v: Optional[str]) -> Optional[str]: raise ValueError("Token should not include 'Bearer ' prefix") return v - @model_validator(mode='after') - def validate_actor_token_type(self) -> 'LoginWithCustomTokenExchangeOptions': + @model_validator(mode="after") + def validate_actor_token_type(self) -> "LoginWithCustomTokenExchangeOptions": """Ensure actor_token_type is provided if actor_token is present.""" if self.actor_token and not self.actor_token_type: raise ValueError("actor_token_type is required when actor_token is provided") @@ -347,13 +375,16 @@ class LoginWithCustomTokenExchangeResult(BaseModel): Contains session data established after token exchange. """ + state_data: dict[str, Any] authorization_details: Optional[list[AuthorizationDetails]] = None + # ============================================================================= # Connected Accounts Types # ============================================================================= + # BASE & SHARED class ConnectedAccountBase(BaseModel): id: str @@ -363,6 +394,7 @@ class ConnectedAccountBase(BaseModel): created_at: str expires_at: Optional[str] = None + # ENTITIES (What exists) class ConnectedAccount(ConnectedAccountBase): id: str @@ -381,6 +413,7 @@ class ConnectedAccountConnection(BaseModel): # Connect Operations (How to connect) + class ConnectAccountOptions(BaseModel): connection: str redirect_uri: Optional[str] = None @@ -388,43 +421,244 @@ class ConnectAccountOptions(BaseModel): app_state: Optional[Any] = None authorization_params: Optional[dict[str, Any]] = None + class ConnectAccountRequest(BaseModel): connection: str scopes: Optional[list[str]] = None redirect_uri: Optional[str] = None state: Optional[str] = None code_challenge: Optional[str] = None - code_challenge_method: Optional[str] = 'S256' + code_challenge_method: Optional[str] = "S256" authorization_params: Optional[dict[str, Any]] = None + class ConnectParams(BaseModel): ticket: str + class ConnectAccountResponse(BaseModel): auth_session: str connect_uri: str connect_params: ConnectParams expires_in: int + class CompleteConnectAccountRequest(BaseModel): auth_session: str connect_code: str redirect_uri: str code_verifier: Optional[str] = None + class CompleteConnectAccountResponse(ConnectedAccountBase): app_state: Optional[Any] = None + # Manage operations class ListConnectedAccountsResponse(BaseModel): accounts: list[ConnectedAccount] next: Optional[str] = None + class ListConnectedAccountConnectionsResponse(BaseModel): connections: list[ConnectedAccountConnection] next: Optional[str] = None +# ============================================================================= +# Passkey & MyAccount Authentication Methods Types +# ============================================================================= + + +class PasskeyRpInfo(BaseModel): + id: str + name: str + + +class PasskeyUserInfo(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: str + name: str + display_name: Optional[str] = Field(None, alias="displayName") + + +class PasskeyPubKeyCredParam(BaseModel): + type: str + alg: int + + +class PasskeyAuthenticatorSelection(BaseModel): + model_config = ConfigDict(populate_by_name=True) + resident_key: Optional[str] = Field(None, alias="residentKey") + user_verification: Optional[str] = Field(None, alias="userVerification") + + +class PasskeyPublicKeyOptions(BaseModel): + model_config = ConfigDict(populate_by_name=True) + challenge: str + rp: Optional[PasskeyRpInfo] = None + rp_id: Optional[str] = Field(None, alias="rpId") + user: Optional[PasskeyUserInfo] = None + pub_key_cred_params: Optional[list[PasskeyPubKeyCredParam]] = Field( + None, alias="pubKeyCredParams" + ) + authenticator_selection: Optional[PasskeyAuthenticatorSelection] = Field( + None, alias="authenticatorSelection" + ) + timeout: Optional[int] = None + user_verification: Optional[str] = Field(None, alias="userVerification") + + +class EnrollAuthenticationMethodRequest(BaseModel): + type: str + email: Optional[str] = None + phone_number: Optional[str] = None + preferred_authentication_method: Optional[str] = None + user_identity_id: Optional[str] = None + connection: Optional[str] = None + + +class EnrollmentChallengeResponse(BaseModel): + authentication_method_id: str + auth_session: str + authn_params_public_key: Optional[PasskeyPublicKeyOptions] = None + + def __repr__(self) -> str: + return ( + f"EnrollmentChallengeResponse(" + f"authentication_method_id={self.authentication_method_id!r}, " + f"auth_session=[REDACTED], " + f"authn_params_public_key={self.authn_params_public_key!r})" + ) + + +class PasskeyAuthResponse(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: str + raw_id: str = Field(alias="rawId") + type: str + authenticator_attachment: Optional[str] = Field(None, alias="authenticatorAttachment") + response: dict[str, str] + client_extension_results: Optional[dict] = Field(None, alias="clientExtensionResults") + + +class VerifyAuthenticationMethodRequest(BaseModel): + auth_session: str + authn_response: Optional[PasskeyAuthResponse] = None + otp_code: Optional[str] = None + recovery_code: Optional[str] = None + password: Optional[str] = None + + @model_validator(mode="after") + def _check_at_least_one_method(self) -> "VerifyAuthenticationMethodRequest": + has_method = ( + self.authn_response is not None + or self.otp_code is not None + or self.recovery_code is not None + or self.password is not None + ) + if not has_method: + raise ValueError( + "At least one verification method must be provided: " + "authn_response, otp_code, recovery_code, or password" + ) + return self + + +class AuthenticationMethod(BaseModel): + model_config = ConfigDict(extra="allow", populate_by_name=True) + + id: str + type: str + created_at: str + confirmed: Optional[bool] = None + usage: Optional[list[str]] = None + identity_user_id: Optional[str] = None + credential_device_type: Optional[str] = None + credential_backed_up: Optional[bool] = None + key_id: Optional[str] = None + public_key: Optional[str] = None + transports: Optional[list[str]] = None + user_agent: Optional[str] = None + user_handle: Optional[str] = None + aaguid: Optional[str] = None + relying_party_id: Optional[str] = None + phone_number: Optional[str] = None + preferred_authentication_method: Optional[str] = None + email: Optional[str] = None + name: Optional[str] = None + last_password_reset: Optional[str] = None + + +class UpdateAuthenticationMethodRequest(BaseModel): + name: Optional[str] = None + preferred_authentication_method: Optional[str] = None + + +class ListAuthenticationMethodsResponse(BaseModel): + authentication_methods: list[AuthenticationMethod] + + +class Factor(BaseModel): + model_config = ConfigDict(extra="allow") + name: str + enabled: Optional[bool] = None + trial_expired: Optional[bool] = None + + +class GetFactorsResponse(BaseModel): + factors: list[Factor] + + +class PasskeySignupChallengeResponse(BaseModel): + model_config = ConfigDict(populate_by_name=True) + auth_session: str + authn_params_public_key: PasskeyPublicKeyOptions + + def __repr__(self) -> str: + return ( + f"PasskeySignupChallengeResponse(" + f"auth_session=[REDACTED], " + f"authn_params_public_key={self.authn_params_public_key!r})" + ) + + +class PasskeyLoginChallengeResponse(BaseModel): + model_config = ConfigDict(populate_by_name=True) + auth_session: str + authn_params_public_key: PasskeyPublicKeyOptions + + def __repr__(self) -> str: + return ( + f"PasskeyLoginChallengeResponse(" + f"auth_session=[REDACTED], " + f"authn_params_public_key={self.authn_params_public_key!r})" + ) + + +class PasskeyTokenResponse(BaseModel): + model_config = ConfigDict(extra="allow", populate_by_name=True) + access_token: str + token_type: str = "Bearer" + expires_in: int + expires_at: Optional[int] = None + scope: Optional[str] = None + id_token: Optional[str] = None + refresh_token: Optional[str] = None + + def __repr__(self) -> str: + return ( + f"PasskeyTokenResponse(" + f"token_type={self.token_type!r}, " + f"expires_in={self.expires_in!r}, " + f"expires_at={self.expires_at!r}, " + f"scope={self.scope!r}, " + f"access_token=[REDACTED], " + f"id_token=[REDACTED], " + f"refresh_token=[REDACTED])" + ) + + # ============================================================================= # MFA Types # ============================================================================= @@ -437,6 +671,7 @@ class ListConnectedAccountConnectionsResponse(BaseModel): class AuthenticatorResponse(BaseModel): """Represents an MFA authenticator enrolled by a user.""" + id: str authenticator_type: AuthenticatorType active: bool @@ -450,14 +685,17 @@ class AuthenticatorResponse(BaseModel): # Enrollment Options + class EnrollOtpOptions(BaseModel): """Options for enrolling an OTP authenticator.""" + authenticator_types: list[str] mfa_token: str class EnrollOobOptions(BaseModel): """Options for enrolling an OOB authenticator (SMS, Voice, Push).""" + authenticator_types: list[str] oob_channels: list[OobChannel] phone_number: Optional[str] = None @@ -466,6 +704,7 @@ class EnrollOobOptions(BaseModel): class EnrollEmailOptions(BaseModel): """Options for enrolling an email authenticator.""" + authenticator_types: list[str] oob_channels: list[OobChannel] email: Optional[str] = None @@ -477,8 +716,10 @@ class EnrollEmailOptions(BaseModel): # Enrollment Responses + class OtpEnrollmentResponse(BaseModel): """Response when enrolling an OTP authenticator.""" + authenticator_type: Literal["otp"] secret: str barcode_uri: str @@ -488,6 +729,7 @@ class OtpEnrollmentResponse(BaseModel): class OobEnrollmentResponse(BaseModel): """Response when enrolling an OOB authenticator.""" + authenticator_type: Literal["oob"] oob_channel: OobChannel oob_code: Optional[str] = None @@ -502,8 +744,10 @@ class OobEnrollmentResponse(BaseModel): # Challenge Types + class ChallengeOptions(BaseModel): """Options for initiating an MFA challenge.""" + challenge_type: ChallengeType authenticator_id: Optional[str] = None mfa_token: str @@ -511,6 +755,7 @@ class ChallengeOptions(BaseModel): class ChallengeResponse(BaseModel): """Response from initiating an MFA challenge.""" + challenge_type: ChallengeType oob_code: Optional[str] = None binding_method: Optional[str] = None @@ -519,21 +764,26 @@ class ChallengeResponse(BaseModel): # List Options + class ListAuthenticatorsOptions(BaseModel): """Options for listing MFA authenticators.""" + mfa_token: str # Verify Types + class VerifyOtpOptions(BaseModel): """Verify with OTP code.""" + mfa_token: str otp: str class VerifyOobOptions(BaseModel): """Verify with OOB code + binding code.""" + mfa_token: str oob_code: str binding_code: str @@ -541,6 +791,7 @@ class VerifyOobOptions(BaseModel): class VerifyRecoveryCodeOptions(BaseModel): """Verify with recovery code.""" + mfa_token: str recovery_code: str @@ -550,6 +801,7 @@ class VerifyRecoveryCodeOptions(BaseModel): class MfaVerifyResponse(BaseModel): """Response from MFA verification.""" + access_token: str token_type: str = "Bearer" expires_in: int @@ -562,24 +814,28 @@ class MfaVerifyResponse(BaseModel): # MFA Requirements + class MfaRequirement(BaseModel): """A single MFA requirement entry.""" + type: str class MfaRequirements(BaseModel): """MFA requirements from an mfa_required error response.""" + enroll: Optional[list[MfaRequirement]] = None challenge: Optional[list[MfaRequirement]] = None # MFA Token Context (for encrypted storage) + class MfaTokenContext(BaseModel): """Internal context stored inside encrypted mfa_token.""" + mfa_token: str audience: str scope: str mfa_requirements: Optional[MfaRequirements] = None created_at: int - diff --git a/src/auth0_server_python/tests/test_dpop_auth.py b/src/auth0_server_python/tests/test_dpop_auth.py new file mode 100644 index 0000000..b6beb69 --- /dev/null +++ b/src/auth0_server_python/tests/test_dpop_auth.py @@ -0,0 +1,145 @@ +import base64 +import hashlib +import json + +import httpx +import pytest +from jwcrypto import jwk + +from auth0_server_python.auth_schemes.bearer_auth import BearerAuth +from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth, _base64url +from auth0_server_python.auth_server.my_account_client import _make_auth + + +@pytest.fixture +def ec_key(): + return jwk.JWK.generate(kty="EC", crv="P-256") + + +def _decode_jwt_parts(token: str) -> tuple[dict, dict]: + parts = token.split(".") + header = json.loads(base64.urlsafe_b64decode(parts[0] + "==")) + payload = json.loads(base64.urlsafe_b64decode(parts[1] + "==")) + return header, payload + + +def test_dpop_headers_set(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + request = httpx.Request("POST", "https://example.com/me/v1/authentication-methods") + flow = auth.auth_flow(request) + modified = next(flow) + + assert modified.headers["Authorization"] == "DPoP test_token" + assert "DPoP" in modified.headers + assert "Bearer" not in modified.headers["Authorization"] + + +def test_dpop_proof_structure(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + request = httpx.Request("POST", "https://example.com/me/v1/authentication-methods") + flow = auth.auth_flow(request) + modified = next(flow) + + proof = modified.headers["DPoP"] + header, payload = _decode_jwt_parts(proof) + + assert header["typ"] == "dpop+jwt" + assert header["alg"] == "ES256" + assert "jwk" in header + assert header["jwk"]["kty"] == "EC" + assert header["jwk"]["crv"] == "P-256" + + assert "jti" in payload + assert payload["htm"] == "POST" + assert payload["htu"] == "https://example.com/me/v1/authentication-methods" + assert "iat" in payload + assert "ath" in payload + + +def test_dpop_htm_binding(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + + get_request = httpx.Request("GET", "https://example.com/me/v1/factors") + flow = auth.auth_flow(get_request) + modified = next(flow) + _, payload = _decode_jwt_parts(modified.headers["DPoP"]) + assert payload["htm"] == "GET" + + post_request = httpx.Request("post", "https://example.com/me/v1/factors") + flow = auth.auth_flow(post_request) + modified = next(flow) + _, payload = _decode_jwt_parts(modified.headers["DPoP"]) + assert payload["htm"] == "POST" + + +def test_dpop_htu_strips_query_and_fragment(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + request = httpx.Request("GET", "https://example.com/me/v1/factors?foo=bar#section") + flow = auth.auth_flow(request) + modified = next(flow) + _, payload = _decode_jwt_parts(modified.headers["DPoP"]) + assert payload["htu"] == "https://example.com/me/v1/factors" + + +def test_dpop_htu_preserves_port(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + request = httpx.Request("GET", "https://example.com:8443/me/v1/factors") + flow = auth.auth_flow(request) + modified = next(flow) + _, payload = _decode_jwt_parts(modified.headers["DPoP"]) + assert payload["htu"] == "https://example.com:8443/me/v1/factors" + + +def test_dpop_ath_binding(ec_key): + token = "my_access_token_value" + auth = DPoPAuth(token=token, key=ec_key) + request = httpx.Request("GET", "https://example.com/me/v1/factors") + flow = auth.auth_flow(request) + modified = next(flow) + _, payload = _decode_jwt_parts(modified.headers["DPoP"]) + + expected_ath = _base64url(hashlib.sha256(token.encode("ascii")).digest()) + assert payload["ath"] == expected_ath + + +def test_dpop_proof_uniqueness(ec_key): + auth = DPoPAuth(token="test_token", key=ec_key) + jtis = set() + for _ in range(10): + request = httpx.Request("GET", "https://example.com/me/v1/factors") + flow = auth.auth_flow(request) + modified = next(flow) + _, payload = _decode_jwt_parts(modified.headers["DPoP"]) + jtis.add(payload["jti"]) + + assert len(jtis) == 10 + + +def test_dpop_repr_redacts_credentials(ec_key): + auth = DPoPAuth(token="secret_access_token_value", key=ec_key) + assert "secret_access_token_value" not in repr(auth) + assert "secret_access_token_value" not in str(auth) + assert "[REDACTED]" in repr(auth) + assert "[REDACTED]" in str(auth) + + +def test_dpop_rejects_non_ec_key(): + rsa_key = jwk.JWK.generate(kty="RSA", size=2048) + with pytest.raises(ValueError, match="EC P-256"): + DPoPAuth(token="token", key=rsa_key) + + +def test_dpop_rejects_wrong_curve(): + p384_key = jwk.JWK.generate(kty="EC", crv="P-384") + with pytest.raises(ValueError, match="EC P-256"): + DPoPAuth(token="token", key=p384_key) + + +def test_make_auth_bearer_fallback(): + auth = _make_auth("token123", dpop_key=None) + assert isinstance(auth, BearerAuth) + + +def test_make_auth_dpop_when_key_provided(ec_key): + auth = _make_auth("token123", dpop_key=ec_key) + assert isinstance(auth, DPoPAuth) diff --git a/src/auth0_server_python/tests/test_passkey_my_account.py b/src/auth0_server_python/tests/test_passkey_my_account.py new file mode 100644 index 0000000..d7f181d --- /dev/null +++ b/src/auth0_server_python/tests/test_passkey_my_account.py @@ -0,0 +1,473 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from jwcrypto import jwk as jwk_module + +from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth +from auth0_server_python.auth_server.my_account_client import MyAccountClient +from auth0_server_python.auth_types import ( + AuthenticationMethod, + EnrollAuthenticationMethodRequest, + EnrollmentChallengeResponse, + GetFactorsResponse, + ListAuthenticationMethodsResponse, + PasskeyAuthResponse, + UpdateAuthenticationMethodRequest, + VerifyAuthenticationMethodRequest, +) +from auth0_server_python.error import ApiError, MissingRequiredArgumentError, MyAccountApiError + + +@pytest.mark.asyncio +async def test_get_factors_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_factors(access_token="token123") + + assert isinstance(result, GetFactorsResponse) + assert len(result.factors) == 1 + assert result.factors[0].name == "sms" + assert result.factors[0].enabled is True + + +@pytest.mark.asyncio +@pytest.mark.parametrize("access_token", [None, ""]) +async def test_get_factors_missing_access_token(mocker, access_token): + client = MyAccountClient(domain="auth0.local") + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.get_factors(access_token=access_token) + + mock_get.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_factors_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 403 + response.json = MagicMock( + return_value={ + "title": "Forbidden", + "type": "forbidden", + "detail": "Insufficient scope", + "status": 403, + } + ) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.get_factors(access_token="token123") + + assert exc.value.status == 403 + + +@pytest.mark.asyncio +async def test_get_factors_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("Connection refused") + ) + + with pytest.raises(ApiError): + await client.get_factors(access_token="token123") + + +@pytest.mark.asyncio +async def test_get_factors_empty_list(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"factors": []}) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_factors(access_token="token123") + assert result.factors == [] + + +@pytest.mark.asyncio +async def test_get_factors_extra_fields(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={ + "factors": [{"name": "webauthn-roaming", "enabled": True, "future_field": "value"}] + } + ) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_factors(access_token="token123") + assert result.factors[0].name == "webauthn-roaming" + + +@pytest.mark.asyncio +async def test_list_authentication_methods_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={ + "authentication_methods": [ + { + "id": "am_1", + "type": "passkey", + "created_at": "2026-01-01T00:00:00Z", + "key_id": "kid1", + } + ] + } + ) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.list_authentication_methods(access_token="token123") + assert isinstance(result, ListAuthenticationMethodsResponse) + assert len(result.authentication_methods) == 1 + assert result.authentication_methods[0].type == "passkey" + + +@pytest.mark.asyncio +async def test_list_authentication_methods_with_type_filter(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"authentication_methods": []}) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + await client.list_authentication_methods(access_token="token123", type_filter="passkey") + mock_get.assert_awaited_once() + call_kwargs = mock_get.call_args[1] + assert call_kwargs["params"] == {"type": "passkey"} + + +@pytest.mark.asyncio +async def test_list_authentication_methods_empty(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"authentication_methods": []}) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.list_authentication_methods(access_token="token123") + assert result.authentication_methods == [] + + +@pytest.mark.asyncio +async def test_get_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} + ) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert isinstance(result, AuthenticationMethod) + assert result.id == "am_1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("method_id", [None, ""]) +async def test_get_authentication_method_missing_id(mocker, method_id): + client = MyAccountClient(domain="auth0.local") + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.get_authentication_method( + access_token="token123", authentication_method_id=method_id + ) + + mock_get.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_authentication_method_path_traversal(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={"id": "id/slash", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} + ) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + await client.get_authentication_method( + access_token="token123", authentication_method_id="id/slash" + ) + call_url = mock_get.call_args[1]["url"] + assert "id%2Fslash" in call_url + assert "id/slash" not in call_url.replace("https://auth0.local/me/", "") + + +@pytest.mark.asyncio +async def test_get_authentication_method_pipe_encoding(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={"id": "passkey|new", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} + ) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + await client.get_authentication_method( + access_token="token123", authentication_method_id="passkey|new" + ) + call_url = mock_get.call_args[1]["url"] + assert "passkey%7Cnew" in call_url + + +@pytest.mark.asyncio +async def test_delete_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 204 + mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response) + + result = await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert result is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("method_id", [None, ""]) +async def test_delete_authentication_method_missing_id(mocker, method_id): + client = MyAccountClient(domain="auth0.local") + mock_delete = mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.delete_authentication_method( + access_token="token123", authentication_method_id=method_id + ) + + mock_delete.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_update_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={ + "id": "am_1", + "type": "passkey", + "created_at": "2026-01-01T00:00:00Z", + "name": "My Key", + } + ) + mock_patch = mocker.patch( + "httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response + ) + + req = UpdateAuthenticationMethodRequest(name="My Key") + result = await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + assert result.name == "My Key" + call_kwargs = mock_patch.call_args[1] + assert call_kwargs["json"] == {"name": "My Key"} + + +@pytest.mark.asyncio +async def test_update_authentication_method_missing_request(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch("httpx.AsyncClient.patch", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=None + ) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} + response.json = MagicMock( + return_value={ + "auth_session": "session_abc", + "authn_params_public_key": { + "challenge": "dGVzdA", + "rp": {"id": "auth0.local", "name": "My App"}, + "user": {"id": "dXNlcl8x", "name": "user@test.com", "displayName": "Test User"}, + "pubKeyCredParams": [{"type": "public-key", "alg": -7}], + "authenticatorSelection": { + "residentKey": "required", + "userVerification": "preferred", + }, + "timeout": 60000, + }, + } + ) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + + assert isinstance(result, EnrollmentChallengeResponse) + assert result.authentication_method_id == "passkey|new" + assert result.auth_session == "session_abc" + assert result.authn_params_public_key is not None + assert result.authn_params_public_key.pub_key_cred_params[0].alg == -7 + assert result.authn_params_public_key.authenticator_selection.resident_key == "required" + assert result.authn_params_public_key.user.display_name == "Test User" + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_missing_location(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.headers = {} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(ApiError) as exc: + await client.enroll_authentication_method(access_token="token123", request=req) + + assert "Location header" in str(exc.value) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_location_with_query(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.headers = {"location": "/me/v1/authentication-methods/abc123?tracking=1"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + assert result.authentication_method_id == "abc123" + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_location_absolute_url(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.headers = {"location": "https://tenant.auth0.com/me/v1/authentication-methods/am_xyz"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + assert result.authentication_method_id == "am_xyz" + + +@pytest.mark.asyncio +async def test_verify_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={ + "id": "am_1", + "type": "passkey", + "created_at": "2026-01-01T00:00:00Z", + "confirmed": True, + } + ) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + authn_response = PasskeyAuthResponse( + id="cred1", + raw_id="cmF3MQ", + type="public-key", + authenticator_attachment="platform", + response={"clientDataJSON": "abc", "attestationObject": "def"}, + ) + req = VerifyAuthenticationMethodRequest( + auth_session="session_abc", authn_response=authn_response + ) + result = await client.verify_authentication_method( + access_token="token123", authentication_method_id="passkey|new", request=req + ) + + assert isinstance(result, AuthenticationMethod) + assert result.confirmed is True + + call_kwargs = mock_post.call_args[1] + body = call_kwargs["json"] + assert "rawId" in body["authn_response"] + assert "raw_id" not in body["authn_response"] + assert "authenticatorAttachment" in body["authn_response"] + assert body["auth_session"] == "session_abc" + assert "passkey%7Cnew" in call_kwargs["url"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("method_id", [None, ""]) +async def test_verify_authentication_method_missing_id(mocker, method_id): + client = MyAccountClient(domain="auth0.local") + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + with pytest.raises(MissingRequiredArgumentError): + await client.verify_authentication_method( + access_token="token123", authentication_method_id=method_id, request=req + ) + + +@pytest.mark.asyncio +async def test_enrollment_challenge_response_repr(): + resp = EnrollmentChallengeResponse( + authentication_method_id="am_1", + auth_session="super_secret_session", + authn_params_public_key=None, + ) + repr_str = repr(resp) + assert "super_secret_session" not in repr_str + assert "[REDACTED]" in repr_str + assert "am_1" in repr_str + + +def test_verify_request_requires_at_least_one_method(): + with pytest.raises(Exception, match="At least one verification method"): + VerifyAuthenticationMethodRequest(auth_session="session_abc") + + +def test_verify_request_accepts_otp_code(): + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + assert req.otp_code == "123456" + + +def test_verify_request_accepts_authn_response(): + authn_resp = PasskeyAuthResponse( + id="cred1", + raw_id="cmF3MQ", + type="public-key", + response={"clientDataJSON": "abc", "attestationObject": "def"}, + ) + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", authn_response=authn_resp) + assert req.authn_response is not None + + +@pytest.mark.asyncio +async def test_get_factors_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.get_factors(access_token="token123", dpop_key=dpop_key) + + mock_get.assert_awaited_once() + call_kwargs = mock_get.call_args[1] + assert isinstance(call_kwargs["auth"], DPoPAuth) diff --git a/src/auth0_server_python/tests/test_passkey_server_client.py b/src/auth0_server_python/tests/test_passkey_server_client.py new file mode 100644 index 0000000..8d39410 --- /dev/null +++ b/src/auth0_server_python/tests/test_passkey_server_client.py @@ -0,0 +1,523 @@ +import time +from unittest.mock import AsyncMock + +import httpx +import pytest + +from auth0_server_python.auth_server.server_client import ServerClient +from auth0_server_python.auth_types import ( + PasskeyAuthResponse, + PasskeyLoginChallengeResponse, + PasskeySignupChallengeResponse, + PasskeyTokenResponse, +) +from auth0_server_python.error import ApiError, MissingRequiredArgumentError + + +@pytest.fixture +def server_client(): + return ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + + +SIGNUP_CHALLENGE_RESPONSE = { + "auth_session": "session_abc123", + "authn_params_public_key": { + "challenge": "dGVzdC1jaGFsbGVuZ2U", + "rp": {"id": "auth0.local", "name": "Test App"}, + "user": {"id": "dXNlcl8x", "name": "user@example.com", "displayName": "Jane"}, + "pubKeyCredParams": [{"type": "public-key", "alg": -7}], + "authenticatorSelection": { + "residentKey": "required", + "userVerification": "preferred", + }, + "timeout": 60000, + }, +} + +LOGIN_CHALLENGE_RESPONSE = { + "auth_session": "session_login_xyz", + "authn_params_public_key": { + "challenge": "bG9naW4tY2hhbGxlbmdl", + "rpId": "auth0.local", + "timeout": 60000, + "userVerification": "preferred", + }, +} + +TOKEN_RESPONSE = { + "access_token": "at_passkey_123", + "id_token": "eyJ.test.jwt", + "token_type": "Bearer", + "expires_in": 86400, + "scope": "openid profile", +} + + +def _mock_response(status_code=200, json_data=None, headers=None): + resp = httpx.Response( + status_code=status_code, + json=json_data, + headers=headers or {}, + request=httpx.Request("POST", "https://auth0.local/passkey/register"), + ) + return resp + + +# ============================================================================= +# passkey_signup_challenge +# ============================================================================= + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_success(server_client, mocker): + mock_response = _mock_response(200, SIGNUP_CHALLENGE_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + result = await server_client.passkey_signup_challenge( + email="user@example.com", + name="Jane Doe", + connection="Username-Password-Authentication", + ) + + assert isinstance(result, PasskeySignupChallengeResponse) + assert result.auth_session == "session_abc123" + assert result.authn_params_public_key.challenge == "dGVzdC1jaGFsbGVuZ2U" + assert result.authn_params_public_key.rp.id == "auth0.local" + assert result.authn_params_public_key.user.display_name == "Jane" + assert result.authn_params_public_key.pub_key_cred_params[0].alg == -7 + assert result.authn_params_public_key.authenticator_selection.resident_key == "required" + + call_args = mock_client.post.call_args + assert "/passkey/register" in call_args.args[0] + body = call_args.kwargs["json"] + assert body["client_id"] == "test_client_id" + assert body["client_secret"] == "test_client_secret" + assert body["user_profile"]["email"] == "user@example.com" + assert body["user_profile"]["name"] == "Jane Doe" + assert body["realm"] == "Username-Password-Authentication" + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_user_profile_fields(server_client, mocker): + mock_response = _mock_response(200, SIGNUP_CHALLENGE_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + await server_client.passkey_signup_challenge( + email="u@e.com", + username="jdoe", + phone_number="+1234567890", + given_name="Jane", + family_name="Doe", + nickname="jd", + picture="https://example.com/pic.jpg", + user_metadata={"role": "admin"}, + organization="org_123", + ) + + body = mock_client.post.call_args.kwargs["json"] + assert body["user_profile"]["email"] == "u@e.com" + assert body["user_profile"]["username"] == "jdoe" + assert body["user_profile"]["phone_number"] == "+1234567890" + assert body["user_profile"]["given_name"] == "Jane" + assert body["user_profile"]["family_name"] == "Doe" + assert body["user_profile"]["nickname"] == "jd" + assert body["user_profile"]["picture"] == "https://example.com/pic.jpg" + assert body["user_profile"]["user_metadata"] == {"role": "admin"} + assert body["organization"] == "org_123" + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_minimal_body(server_client, mocker): + mock_response = _mock_response(200, SIGNUP_CHALLENGE_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + await server_client.passkey_signup_challenge() + + body = mock_client.post.call_args.kwargs["json"] + assert body == {"client_id": "test_client_id", "client_secret": "test_client_secret"} + assert "user_profile" not in body + assert "realm" not in body + assert "organization" not in body + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_api_error(server_client, mocker): + error_resp = _mock_response( + 403, + {"error": "access_denied", "error_description": "Passkey not enabled"}, + ) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=error_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + with pytest.raises(ApiError) as exc: + await server_client.passkey_signup_challenge(email="test@example.com") + assert "access_denied" in str(exc.value) or "Passkey not enabled" in str(exc.value) + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_non_json_error(server_client, mocker): + resp = httpx.Response( + status_code=502, + content=b"Bad Gateway", + headers={"content-type": "text/html"}, + request=httpx.Request("POST", "https://auth0.local/passkey/register"), + ) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + with pytest.raises(ApiError) as exc: + await server_client.passkey_signup_challenge() + assert "502" in str(exc.value) or "passkey_challenge_error" in str(exc.value) + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_network_error(server_client, mocker): + mock_client = AsyncMock() + mock_client.post = AsyncMock(side_effect=Exception("Connection refused")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + with pytest.raises(ApiError) as exc: + await server_client.passkey_signup_challenge() + assert "Passkey signup challenge failed" in str(exc.value) + + +# ============================================================================= +# passkey_login_challenge +# ============================================================================= + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_success(server_client, mocker): + mock_response = _mock_response(200, LOGIN_CHALLENGE_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + result = await server_client.passkey_login_challenge( + connection="Username-Password-Authentication", + organization="org_abc", + ) + + assert isinstance(result, PasskeyLoginChallengeResponse) + assert result.auth_session == "session_login_xyz" + assert result.authn_params_public_key.challenge == "bG9naW4tY2hhbGxlbmdl" + assert result.authn_params_public_key.rp_id == "auth0.local" + assert result.authn_params_public_key.user_verification == "preferred" + + body = mock_client.post.call_args.kwargs["json"] + assert body["client_id"] == "test_client_id" + assert body["realm"] == "Username-Password-Authentication" + assert body["organization"] == "org_abc" + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_with_username(server_client, mocker): + mock_response = _mock_response(200, LOGIN_CHALLENGE_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + await server_client.passkey_login_challenge(username="jane@example.com") + + body = mock_client.post.call_args.kwargs["json"] + assert body["username"] == "jane@example.com" + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_api_error(server_client, mocker): + error_resp = _mock_response( + 400, + {"error": "invalid_request", "error_description": "Missing client_id"}, + ) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=error_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + with pytest.raises(ApiError): + await server_client.passkey_login_challenge() + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_network_error(server_client, mocker): + mock_client = AsyncMock() + mock_client.post = AsyncMock(side_effect=Exception("timeout")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + + with pytest.raises(ApiError): + await server_client.passkey_login_challenge() + + +# ============================================================================= +# signin_with_passkey +# ============================================================================= + + +@pytest.fixture +def authn_response(): + return PasskeyAuthResponse( + id="cred_abc123", + raw_id="Y3JlZF9hYmMxMjM", + type="public-key", + authenticator_attachment="platform", + response={ + "clientDataJSON": "eyJ0eXBlIjoid2ViYXV0aG4uZ2V0In0", + "authenticatorData": "SZYN5YgOjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2M", + "signature": "MEUCIQC", + "userHandle": "dXNlcl8x", + }, + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_success(server_client, authn_response, mocker): + mock_response = _mock_response(200, TOKEN_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + result = await server_client.signin_with_passkey( + auth_session="session_xyz", + authn_response=authn_response, + scope="openid profile", + audience="https://api.example.com", + connection="Username-Password-Authentication", + organization="org_abc", + ) + + assert isinstance(result, PasskeyTokenResponse) + assert result.access_token == "at_passkey_123" + assert result.token_type == "Bearer" + assert abs(result.expires_at - (int(time.time()) + 86400)) <= 2 + + body = mock_client.post.call_args.kwargs["json"] + assert body["grant_type"] == "urn:okta:params:oauth:grant-type:webauthn" + assert body["client_id"] == "test_client_id" + assert body["client_secret"] == "test_client_secret" + assert body["auth_session"] == "session_xyz" + assert body["scope"] == "openid profile" + assert body["audience"] == "https://api.example.com" + assert body["realm"] == "Username-Password-Authentication" + assert body["organization"] == "org_abc" + assert body["authn_response"]["rawId"] == "Y3JlZF9hYmMxMjM" + assert body["authn_response"]["authenticatorAttachment"] == "platform" + assert "raw_id" not in body["authn_response"] + + +@pytest.mark.asyncio +async def test_signin_with_passkey_uses_json_content_type(server_client, authn_response, mocker): + mock_response = _mock_response(200, TOKEN_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + await server_client.signin_with_passkey( + auth_session="s", + authn_response=authn_response, + ) + + call_kwargs = mock_client.post.call_args.kwargs + assert "json" in call_kwargs + assert "data" not in call_kwargs + + +@pytest.mark.asyncio +@pytest.mark.parametrize("auth_session", [None, ""]) +async def test_signin_with_passkey_missing_auth_session( + server_client, authn_response, auth_session +): + with pytest.raises(MissingRequiredArgumentError): + await server_client.signin_with_passkey( + auth_session=auth_session, + authn_response=authn_response, + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_authn_response(server_client): + with pytest.raises(MissingRequiredArgumentError): + await server_client.signin_with_passkey( + auth_session="session_abc", + authn_response=None, + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_api_error(server_client, authn_response, mocker): + error_resp = _mock_response( + 401, + {"error": "invalid_grant", "error_description": "Invalid auth_session"}, + ) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=error_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + with pytest.raises(ApiError) as exc: + await server_client.signin_with_passkey( + auth_session="expired_session", + authn_response=authn_response, + ) + assert "invalid_grant" in str(exc.value) or "Invalid auth_session" in str(exc.value) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_token_endpoint(server_client, authn_response, mocker): + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={}, + ) + + with pytest.raises(ApiError) as exc: + await server_client.signin_with_passkey( + auth_session="session", + authn_response=authn_response, + ) + assert "token endpoint" in str(exc.value).lower() + + +@pytest.mark.asyncio +async def test_signin_with_passkey_network_error(server_client, authn_response, mocker): + mock_client = AsyncMock() + mock_client.post = AsyncMock(side_effect=Exception("Connection reset")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + with pytest.raises(ApiError): + await server_client.signin_with_passkey( + auth_session="session", + authn_response=authn_response, + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_no_client_secret(mocker): + client = ServerClient( + domain="auth0.local", + client_id="public_client", + client_secret=None, + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret", + ) + + mock_response = _mock_response(200, TOKEN_RESPONSE) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + authn_resp = PasskeyAuthResponse( + id="cred", + raw_id="cmF3", + type="public-key", + response={"clientDataJSON": "abc", "authenticatorData": "def", "signature": "ghi"}, + ) + + await client.signin_with_passkey( + auth_session="session", + authn_response=authn_resp, + ) + + body = mock_client.post.call_args.kwargs["json"] + assert "client_secret" not in body + assert body["client_id"] == "public_client" + + +@pytest.mark.asyncio +async def test_signup_challenge_repr_redacts_auth_session(): + resp = PasskeySignupChallengeResponse.model_validate(SIGNUP_CHALLENGE_RESPONSE) + repr_str = repr(resp) + assert "session_abc123" not in repr_str + assert "[REDACTED]" in repr_str + + +@pytest.mark.asyncio +async def test_login_challenge_repr_redacts_auth_session(): + resp = PasskeyLoginChallengeResponse.model_validate(LOGIN_CHALLENGE_RESPONSE) + repr_str = repr(resp) + assert "session_login_xyz" not in repr_str + assert "[REDACTED]" in repr_str + + +def test_passkey_token_response_repr_redacts_tokens(): + resp = PasskeyTokenResponse( + access_token="secret_at_value", + token_type="Bearer", + expires_in=86400, + id_token="secret_id_token", + refresh_token="secret_rt_value", + ) + repr_str = repr(resp) + assert "secret_at_value" not in repr_str + assert "secret_id_token" not in repr_str + assert "secret_rt_value" not in repr_str + assert "[REDACTED]" in repr_str + assert "86400" in repr_str From ec66c0f30db6d51f3319fb9d03fc55a109e80fa7 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Tue, 2 Jun 2026 14:12:56 +0530 Subject: [PATCH 2/9] Added missing test cases and edge case fix --- .../auth_server/my_account_client.py | 12 +- .../tests/test_passkey_my_account.py | 357 ++++++++++++++++++ .../tests/test_passkey_server_client.py | 62 +++ 3 files changed, 427 insertions(+), 4 deletions(-) diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index a6aed8f..9089186 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -643,10 +643,14 @@ async def enroll_authentication_method( "Enrollment succeeded (201) but Location header is missing", ) - authentication_method_id = ( - location.split("?")[0].split("#")[0].rstrip("/").split("/")[-1] - ) - if not authentication_method_id: + path = location.split("?")[0].split("#")[0].rstrip("/") + segments = path.split("/") + authentication_method_id = segments[-1] if len(segments) > 1 else "" + if not authentication_method_id or authentication_method_id in ( + "authentication-methods", + "v1", + "me", + ): raise ApiError( "enroll_authentication_method_error", "Enrollment succeeded (201) but could not extract ID from Location header", diff --git a/src/auth0_server_python/tests/test_passkey_my_account.py b/src/auth0_server_python/tests/test_passkey_my_account.py index d7f181d..4b7f29d 100644 --- a/src/auth0_server_python/tests/test_passkey_my_account.py +++ b/src/auth0_server_python/tests/test_passkey_my_account.py @@ -471,3 +471,360 @@ async def test_get_factors_with_dpop_key(mocker): mock_get.assert_awaited_once() call_kwargs = mock_get.call_args[1] assert isinstance(call_kwargs["auth"], DPoPAuth) + + +# ============================================================================= +# DPoP integration(mock) tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_list_authentication_methods_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"authentication_methods": []}) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.list_authentication_methods(access_token="token123", dpop_key=dpop_key) + + assert isinstance(mock_get.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_get_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} + ) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1", dpop_key=dpop_key + ) + + assert isinstance(mock_get.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_delete_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 204 + mock_delete = mocker.patch( + "httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1", dpop_key=dpop_key + ) + + assert isinstance(mock_delete.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_update_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} + ) + mock_patch = mocker.patch( + "httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + req = UpdateAuthenticationMethodRequest(name="New Name") + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req, dpop_key=dpop_key + ) + + assert isinstance(mock_patch.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + req = EnrollAuthenticationMethodRequest(type="passkey") + await client.enroll_authentication_method( + access_token="token123", request=req, dpop_key=dpop_key + ) + + assert isinstance(mock_post.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_verify_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock( + return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} + ) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + await client.verify_authentication_method( + access_token="token123", + authentication_method_id="am_1", + request=req, + dpop_key=dpop_key, + ) + + assert isinstance(mock_post.call_args[1]["auth"], DPoPAuth) + + +# ============================================================================= +# API error and network error tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_list_authentication_methods_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 403 + response.json = MagicMock( + return_value={ + "title": "Forbidden", + "type": "forbidden", + "detail": "Insufficient scope", + "status": 403, + } + ) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.list_authentication_methods(access_token="token123") + assert exc.value.status == 403 + + +@pytest.mark.asyncio +async def test_list_authentication_methods_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("Connection refused") + ) + + with pytest.raises(ApiError): + await client.list_authentication_methods(access_token="token123") + + +@pytest.mark.asyncio +async def test_get_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 404 + response.json = MagicMock( + return_value={ + "title": "Not Found", + "type": "not_found", + "detail": "Not found", + "status": 404, + } + ) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert exc.value.status == 404 + + +@pytest.mark.asyncio +async def test_get_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("timeout")) + + with pytest.raises(ApiError): + await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + + +@pytest.mark.asyncio +async def test_delete_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 404 + response.json = MagicMock( + return_value={ + "title": "Not Found", + "type": "not_found", + "detail": "Not found", + "status": 404, + } + ) + mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert exc.value.status == 404 + + +@pytest.mark.asyncio +async def test_delete_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.delete", + new_callable=AsyncMock, + side_effect=Exception("Connection reset"), + ) + + with pytest.raises(ApiError): + await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + + +@pytest.mark.asyncio +async def test_update_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 422 + response.json = MagicMock( + return_value={ + "title": "Unprocessable", + "type": "validation_error", + "detail": "Invalid", + "status": 422, + } + ) + mocker.patch("httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response) + + req = UpdateAuthenticationMethodRequest(name="x") + with pytest.raises(MyAccountApiError) as exc: + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + assert exc.value.status == 422 + + +@pytest.mark.asyncio +async def test_update_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.patch", new_callable=AsyncMock, side_effect=Exception("timeout") + ) + + req = UpdateAuthenticationMethodRequest(name="x") + with pytest.raises(ApiError): + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 403 + response.json = MagicMock( + return_value={ + "title": "Forbidden", + "type": "forbidden", + "detail": "Scope missing", + "status": 403, + } + ) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(MyAccountApiError) as exc: + await client.enroll_authentication_method(access_token="token123", request=req) + assert exc.value.status == 403 + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.post", + new_callable=AsyncMock, + side_effect=Exception("Connection refused"), + ) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(ApiError): + await client.enroll_authentication_method(access_token="token123", request=req) + + +@pytest.mark.asyncio +async def test_verify_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 400 + response.json = MagicMock( + return_value={ + "title": "Bad Request", + "type": "invalid_request", + "detail": "Invalid OTP", + "status": 400, + } + ) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="000000") + with pytest.raises(MyAccountApiError) as exc: + await client.verify_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + assert exc.value.status == 400 + + +@pytest.mark.asyncio +async def test_verify_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.post", + new_callable=AsyncMock, + side_effect=Exception("Connection refused"), + ) + + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + with pytest.raises(ApiError): + await client.verify_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + + +# ============================================================================= +# Location header extraction edge case +# ============================================================================= + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_location_collection_url(mocker): + """Rejects Location header that ends at collection path without resource ID.""" + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.headers = {"location": "/me/v1/authentication-methods/"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(ApiError) as exc: + await client.enroll_authentication_method(access_token="token123", request=req) + assert "could not extract ID" in str(exc.value) diff --git a/src/auth0_server_python/tests/test_passkey_server_client.py b/src/auth0_server_python/tests/test_passkey_server_client.py index 8d39410..7c2be37 100644 --- a/src/auth0_server_python/tests/test_passkey_server_client.py +++ b/src/auth0_server_python/tests/test_passkey_server_client.py @@ -521,3 +521,65 @@ def test_passkey_token_response_repr_redacts_tokens(): assert "secret_rt_value" not in repr_str assert "[REDACTED]" in repr_str assert "86400" in repr_str + + +# ============================================================================= +# expires_at edge cases +# ============================================================================= + + +@pytest.mark.asyncio +async def test_signin_with_passkey_preserves_server_expires_at( + server_client, authn_response, mocker +): + token_data = { + "access_token": "at_123", + "token_type": "Bearer", + "expires_in": 3600, + "expires_at": 9999999999, + } + mock_response = _mock_response(200, token_data) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + result = await server_client.signin_with_passkey( + auth_session="session", authn_response=authn_response + ) + + assert result.expires_at == 9999999999 + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_expires_at_calculates( + server_client, authn_response, mocker +): + token_data = { + "access_token": "at_123", + "token_type": "Bearer", + "expires_in": 60, + } + mock_response = _mock_response(200, token_data) + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) + mocker.patch.object( + server_client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + + result = await server_client.signin_with_passkey( + auth_session="session", authn_response=authn_response + ) + + assert abs(result.expires_at - (int(time.time()) + 60)) <= 2 From d2d1f216d4e159027d4bf85ef60a636fd13d2078 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Tue, 2 Jun 2026 15:45:22 +0530 Subject: [PATCH 3/9] Resolved snake case to camel case for correct parsing --- src/auth0_server_python/auth_server/server_client.py | 10 +++++----- .../tests/test_passkey_server_client.py | 9 +++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 334eb00..d5118c7 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -2472,23 +2472,23 @@ async def passkey_signup_challenge( if username is not None: user_profile["username"] = username if phone_number is not None: - user_profile["phone_number"] = phone_number + user_profile["phoneNumber"] = phone_number if given_name is not None: - user_profile["given_name"] = given_name + user_profile["givenName"] = given_name if family_name is not None: - user_profile["family_name"] = family_name + user_profile["familyName"] = family_name if nickname is not None: user_profile["nickname"] = nickname if picture is not None: user_profile["picture"] = picture - if user_metadata is not None: - user_profile["user_metadata"] = user_metadata body: dict[str, Any] = {"client_id": self._client_id} if self._client_secret: body["client_secret"] = self._client_secret if user_profile: body["user_profile"] = user_profile + if user_metadata is not None: + body["userMetadata"] = user_metadata if connection: body["realm"] = connection if organization: diff --git a/src/auth0_server_python/tests/test_passkey_server_client.py b/src/auth0_server_python/tests/test_passkey_server_client.py index 7c2be37..2d644af 100644 --- a/src/auth0_server_python/tests/test_passkey_server_client.py +++ b/src/auth0_server_python/tests/test_passkey_server_client.py @@ -132,12 +132,13 @@ async def test_passkey_signup_challenge_user_profile_fields(server_client, mocke body = mock_client.post.call_args.kwargs["json"] assert body["user_profile"]["email"] == "u@e.com" assert body["user_profile"]["username"] == "jdoe" - assert body["user_profile"]["phone_number"] == "+1234567890" - assert body["user_profile"]["given_name"] == "Jane" - assert body["user_profile"]["family_name"] == "Doe" + assert body["user_profile"]["phoneNumber"] == "+1234567890" + assert body["user_profile"]["givenName"] == "Jane" + assert body["user_profile"]["familyName"] == "Doe" assert body["user_profile"]["nickname"] == "jd" assert body["user_profile"]["picture"] == "https://example.com/pic.jpg" - assert body["user_profile"]["user_metadata"] == {"role": "admin"} + assert "user_metadata" not in body["user_profile"] + assert body["userMetadata"] == {"role": "admin"} assert body["organization"] == "org_123" From 7691a0c12fead5a655fd58860249a349daeb7c09 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Tue, 2 Jun 2026 16:05:59 +0530 Subject: [PATCH 4/9] Reverting to snake case - as per auth0 api docs. --- src/auth0_server_python/auth_server/server_client.py | 10 +++++----- .../tests/test_passkey_server_client.py | 9 ++++----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index d5118c7..334eb00 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -2472,23 +2472,23 @@ async def passkey_signup_challenge( if username is not None: user_profile["username"] = username if phone_number is not None: - user_profile["phoneNumber"] = phone_number + user_profile["phone_number"] = phone_number if given_name is not None: - user_profile["givenName"] = given_name + user_profile["given_name"] = given_name if family_name is not None: - user_profile["familyName"] = family_name + user_profile["family_name"] = family_name if nickname is not None: user_profile["nickname"] = nickname if picture is not None: user_profile["picture"] = picture + if user_metadata is not None: + user_profile["user_metadata"] = user_metadata body: dict[str, Any] = {"client_id": self._client_id} if self._client_secret: body["client_secret"] = self._client_secret if user_profile: body["user_profile"] = user_profile - if user_metadata is not None: - body["userMetadata"] = user_metadata if connection: body["realm"] = connection if organization: diff --git a/src/auth0_server_python/tests/test_passkey_server_client.py b/src/auth0_server_python/tests/test_passkey_server_client.py index 2d644af..7c2be37 100644 --- a/src/auth0_server_python/tests/test_passkey_server_client.py +++ b/src/auth0_server_python/tests/test_passkey_server_client.py @@ -132,13 +132,12 @@ async def test_passkey_signup_challenge_user_profile_fields(server_client, mocke body = mock_client.post.call_args.kwargs["json"] assert body["user_profile"]["email"] == "u@e.com" assert body["user_profile"]["username"] == "jdoe" - assert body["user_profile"]["phoneNumber"] == "+1234567890" - assert body["user_profile"]["givenName"] == "Jane" - assert body["user_profile"]["familyName"] == "Doe" + assert body["user_profile"]["phone_number"] == "+1234567890" + assert body["user_profile"]["given_name"] == "Jane" + assert body["user_profile"]["family_name"] == "Doe" assert body["user_profile"]["nickname"] == "jd" assert body["user_profile"]["picture"] == "https://example.com/pic.jpg" - assert "user_metadata" not in body["user_profile"] - assert body["userMetadata"] == {"role": "admin"} + assert body["user_profile"]["user_metadata"] == {"role": "admin"} assert body["organization"] == "org_123" From 3233fff43ee6985053b8b261a8838057df6578a2 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Tue, 2 Jun 2026 17:15:27 +0530 Subject: [PATCH 5/9] Reverted lint fixes for easier review --- .../auth_server/my_account_client.py | 56 +- .../auth_server/server_client.py | 683 ++++++++++-------- 2 files changed, 410 insertions(+), 329 deletions(-) diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index 9089186..bd23a12 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -74,7 +74,9 @@ def audience(self): return f"https://{self._domain}/me/" async def connect_account( - self, access_token: str, request: ConnectAccountRequest + self, + access_token: str, + request: ConnectAccountRequest ) -> ConnectAccountResponse: """ Initiate the connected account flow. @@ -95,7 +97,7 @@ async def connect_account( response = await client.post( url=f"{self.audience}v1/connected-accounts/connect", json=request.model_dump(exclude_none=True), - auth=BearerAuth(access_token), + auth=BearerAuth(access_token) ) if response.status_code != 201: @@ -105,7 +107,7 @@ async def connect_account( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None), + validation_errors=error_data.get("validation_errors", None) ) data = response.json() @@ -118,11 +120,13 @@ async def connect_account( raise ApiError( "connect_account_error", f"Connected Accounts connect request failed: {str(e) or 'Unknown error'}", - e, + e ) async def complete_connect_account( - self, access_token: str, request: CompleteConnectAccountRequest + self, + access_token: str, + request: CompleteConnectAccountRequest ) -> CompleteConnectAccountResponse: """ Complete the connected account flow after user authorization. @@ -143,7 +147,7 @@ async def complete_connect_account( response = await client.post( url=f"{self.audience}v1/connected-accounts/complete", json=request.model_dump(exclude_none=True), - auth=BearerAuth(access_token), + auth=BearerAuth(access_token) ) if response.status_code != 201: @@ -153,7 +157,7 @@ async def complete_connect_account( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None), + validation_errors=error_data.get("validation_errors", None) ) data = response.json() @@ -166,7 +170,7 @@ async def complete_connect_account( raise ApiError( "connect_account_error", f"Connected Accounts complete request failed: {str(e) or 'Unknown error'}", - e, + e ) async def list_connected_accounts( @@ -174,7 +178,7 @@ async def list_connected_accounts( access_token: str, connection: Optional[str] = None, from_param: Optional[str] = None, - take: Optional[int] = None, + take: Optional[int] = None ) -> ListConnectedAccountsResponse: """ List connected accounts for the authenticated user. @@ -213,7 +217,7 @@ async def list_connected_accounts( response = await client.get( url=f"{self.audience}v1/connected-accounts/accounts", params=params, - auth=BearerAuth(access_token), + auth=BearerAuth(access_token) ) if response.status_code != 200: @@ -223,7 +227,7 @@ async def list_connected_accounts( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None), + validation_errors=error_data.get("validation_errors", None) ) data = response.json() @@ -236,10 +240,15 @@ async def list_connected_accounts( raise ApiError( "connect_account_error", f"Connected Accounts list request failed: {str(e) or 'Unknown error'}", - e, + e ) - async def delete_connected_account(self, access_token: str, connected_account_id: str) -> None: + + async def delete_connected_account( + self, + access_token: str, + connected_account_id: str + ) -> None: """ Delete a connected account for the authenticated user. @@ -266,7 +275,7 @@ async def delete_connected_account(self, access_token: str, connected_account_id async with self._get_http_client() as client: response = await client.delete( url=f"{self.audience}v1/connected-accounts/accounts/{connected_account_id}", - auth=BearerAuth(access_token), + auth=BearerAuth(access_token) ) if response.status_code != 204: @@ -276,7 +285,7 @@ async def delete_connected_account(self, access_token: str, connected_account_id type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None), + validation_errors=error_data.get("validation_errors", None) ) except Exception as e: @@ -285,11 +294,14 @@ async def delete_connected_account(self, access_token: str, connected_account_id raise ApiError( "connect_account_error", f"Connected Accounts delete request failed: {str(e) or 'Unknown error'}", - e, + e ) async def list_connected_account_connections( - self, access_token: str, from_param: Optional[str] = None, take: Optional[int] = None + self, + access_token: str, + from_param: Optional[str] = None, + take: Optional[int] = None ) -> ListConnectedAccountConnectionsResponse: """ List available connections that support connected accounts. @@ -325,7 +337,7 @@ async def list_connected_account_connections( response = await client.get( url=f"{self.audience}v1/connected-accounts/connections", params=params, - auth=BearerAuth(access_token), + auth=BearerAuth(access_token) ) if response.status_code != 200: @@ -335,7 +347,7 @@ async def list_connected_account_connections( type=error_data.get("type", None), detail=error_data.get("detail", None), status=error_data.get("status", None), - validation_errors=error_data.get("validation_errors", None), + validation_errors=error_data.get("validation_errors", None) ) data = response.json() @@ -348,9 +360,13 @@ async def list_connected_account_connections( raise ApiError( "connect_account_error", f"Connected Accounts list connections request failed: {str(e) or 'Unknown error'}", - e, + e ) + # ============================================================================ + # AUTHENTICATION METHODS & FACTORS (Passkey / MyAccount API) + # ============================================================================ + async def get_factors( self, access_token: str, diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 334eb00..a233205 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -69,18 +69,11 @@ ) # Generic type for store options -TStoreOptions = TypeVar("TStoreOptions") +TStoreOptions = TypeVar('TStoreOptions') # redirect_uri is intentionally excluded — in MCD mode it is built # dynamically from the resolved domain at login time. -INTERNAL_AUTHORIZE_PARAMS = [ - "client_id", - "response_type", - "code_challenge", - "code_challenge_method", - "state", - "nonce", - "scope", -] +INTERNAL_AUTHORIZE_PARAMS = ["client_id", "response_type", + "code_challenge", "code_challenge_method", "state", "nonce", "scope"] class ServerClient(Generic[TStoreOptions]): @@ -88,7 +81,6 @@ class ServerClient(Generic[TStoreOptions]): Main client for Auth0 server SDK. Handles authentication flows, session management, and token operations using Authlib for OIDC functionality. """ - DEFAULT_AUDIENCE_STATE_KEY = "default" # ============================================================================ @@ -129,7 +121,9 @@ def __init__( raise MissingRequiredArgumentError("secret") if domain is None: - raise ConfigurationError("Domain is required") + raise ConfigurationError( + "Domain is required" + ) # Validate domain type if not isinstance(domain, str) and not callable(domain): @@ -174,12 +168,14 @@ def __init__( headers=self._telemetry_headers, ) - self._my_account_client = MyAccountClient(domain=domain, headers=self._telemetry_headers) + self._my_account_client = MyAccountClient( + domain=domain, headers=self._telemetry_headers + ) # Unified cache for OIDC metadata and JWKS per domain (LRU eviction + TTL) self._discovery_cache: OrderedDict[str, dict] = OrderedDict() - self._cache_ttl = 600 # 10 mins. TTL - self._cache_max_entries = 100 # Max 100 domains + self._cache_ttl = 600 # 10 mins. TTL + self._cache_max_entries = 100 # Max 100 domains # Initialize MFA client self._mfa_client = MfaClient( @@ -206,14 +202,14 @@ def _normalize_url(self, value: str) -> str: return value value = value.lower() - if value.startswith("https://"): + if value.startswith('https://'): pass - elif value.startswith("http://"): - value = value.replace("http://", "https://") + elif value.startswith('http://'): + value = value.replace('http://', 'https://') else: - value = f"https://{value}" + value = f'https://{value}' - return value.rstrip("/") + return value.rstrip('/') async def _resolve_current_domain(self, store_options=None) -> str: """Resolve domain from resolver function or return static domain.""" @@ -226,7 +222,8 @@ async def _resolve_current_domain(self, store_options=None) -> str: raise except Exception as e: raise DomainResolverError( - f"Domain resolver function raised an exception: {str(e)}", original_error=e + f"Domain resolver function raised an exception: {str(e)}", + original_error=e ) return self._domain @@ -240,18 +237,18 @@ def _get_session_domain(self, state_data_dict: dict) -> Optional[str]: 2. self._domain — static domain (if configured) 3. Extract hostname from user.iss — derive from user's issuer claim """ - domain = state_data_dict.get("domain") + domain = state_data_dict.get('domain') if domain: return domain if self._domain: return self._domain - user = state_data_dict.get("user") + user = state_data_dict.get('user') if isinstance(user, dict): - iss = user.get("iss") + iss = user.get('iss') else: - iss = getattr(user, "iss", None) if user else None + iss = getattr(user, 'iss', None) if user else None if iss: parsed = urlparse(iss) @@ -354,7 +351,7 @@ async def _get_oidc_metadata_cached(self, domain: str) -> dict: self._discovery_cache[domain] = { "metadata": metadata, "jwks": None, - "expires_at": now + self._cache_ttl, + "expires_at": now + self._cache_ttl } return metadata @@ -416,11 +413,11 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: if not metadata: metadata = await self._get_oidc_metadata_cached(domain) - jwks_uri = metadata.get("jwks_uri") + jwks_uri = metadata.get('jwks_uri') if not jwks_uri: raise ApiError( "missing_jwks_uri", - f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant.", + f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant." ) # Fetch JWKS @@ -437,7 +434,7 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: self._discovery_cache[domain] = { "metadata": metadata, "jwks": jwks, - "expires_at": now + self._cache_ttl, + "expires_at": now + self._cache_ttl } return jwks @@ -449,7 +446,9 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: # ============================================================================ async def start_interactive_login( - self, options: Optional[StartInteractiveLoginOptions] = None, store_options: dict = None + self, + options: Optional[StartInteractiveLoginOptions] = None, + store_options: dict = None ) -> str: """ Starts the interactive login process and returns a URL to redirect to. @@ -470,17 +469,15 @@ async def start_interactive_login( try: metadata = await self._get_oidc_metadata_cached(origin_domain) except Exception as e: - raise ApiError("metadata_error", "Failed to fetch OIDC metadata", e) + raise ApiError("metadata_error", + "Failed to fetch OIDC metadata", e) # Get effective authorization params (merge defaults with provided ones) auth_params = dict(self._default_authorization_params) if options.authorization_params: auth_params.update( - { - k: v - for k, v in options.authorization_params.items() - if k not in INTERNAL_AUTHORIZE_PARAMS - } + {k: v for k, v in options.authorization_params.items( + ) if k not in INTERNAL_AUTHORIZE_PARAMS} ) # Ensure we have a redirect_uri @@ -504,11 +501,7 @@ async def start_interactive_login( auth_params["state"] = state # Merge any requested scope with defaults - requested_scope = ( - options.authorization_params.get("scope", None) - if options.authorization_params - else None - ) + requested_scope = options.authorization_params.get("scope", None) if options.authorization_params else None audience = auth_params.get("audience", None) merged_scope = self._merge_scope_with_defaults(requested_scope, audience) auth_params["scope"] = merged_scope @@ -524,61 +517,65 @@ async def start_interactive_login( # Store the transaction data await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", transaction_data, options=store_options + f"{self._transaction_identifier}:{state}", + transaction_data, + options=store_options ) # Set metadata for OAuth client self._oauth.metadata = metadata # If PAR is enabled, use the PAR endpoint if self._pushed_authorization_requests: - par_endpoint = self._oauth.metadata.get("pushed_authorization_request_endpoint") + par_endpoint = self._oauth.metadata.get( + "pushed_authorization_request_endpoint") if not par_endpoint: raise ApiError( - "configuration_error", - "PAR is enabled but pushed_authorization_request_endpoint is missing in metadata", - ) + "configuration_error", "PAR is enabled but pushed_authorization_request_endpoint is missing in metadata") auth_params["client_id"] = self._client_id # Post the auth_params to the PAR endpoint async with self._get_http_client() as client: par_response = await client.post( - par_endpoint, data=auth_params, auth=(self._client_id, self._client_secret) + par_endpoint, + data=auth_params, + auth=(self._client_id, self._client_secret) ) if par_response.status_code not in (200, 201): error_data = par_response.json() raise ApiError( error_data.get("error", "par_error"), error_data.get( - "error_description", "Failed to obtain request_uri from PAR endpoint" - ), + "error_description", "Failed to obtain request_uri from PAR endpoint") ) par_data = par_response.json() request_uri = par_data.get("request_uri") if not request_uri: - raise ApiError("par_error", "No request_uri returned from PAR endpoint") + raise ApiError( + "par_error", "No request_uri returned from PAR endpoint") auth_endpoint = self._oauth.metadata.get("authorization_endpoint") final_url = f"{auth_endpoint}?request_uri={request_uri}&response_type={auth_params['response_type']}&client_id={self._client_id}" return final_url else: if "authorization_endpoint" not in self._oauth.metadata: - raise ApiError( - "configuration_error", "Authorization endpoint missing in OIDC metadata" - ) + raise ApiError("configuration_error", + "Authorization endpoint missing in OIDC metadata") authorization_endpoint = self._oauth.metadata["authorization_endpoint"] try: auth_url, state = self._oauth.create_authorization_url( - authorization_endpoint, **auth_params - ) + authorization_endpoint, **auth_params) except Exception as e: - raise ApiError("authorization_url_error", "Failed to create authorization URL", e) + raise ApiError("authorization_url_error", + "Failed to create authorization URL", e) return auth_url async def complete_interactive_login( - self, url: str, store_options: dict = None + self, + url: str, + store_options: dict = None ) -> dict[str, Any]: """ Completes the login process after user is redirected back. @@ -601,9 +598,7 @@ async def complete_interactive_login( # Retrieve the transaction data using the state transaction_identifier = f"{self._transaction_identifier}:{state}" - transaction_data = await self._transaction_store.get( - transaction_identifier, options=store_options - ) + transaction_data = await self._transaction_store.get(transaction_identifier, options=store_options) if not transaction_data: raise MissingTransactionError() @@ -624,7 +619,7 @@ async def complete_interactive_login( # Fetch metadata and derive issuer from the origin domain metadata = await self._get_oidc_metadata_cached(origin_domain) - origin_issuer = metadata.get("issuer") + origin_issuer = metadata.get('issuer') self._oauth.metadata = metadata # Exchange the code for tokens @@ -640,7 +635,8 @@ async def complete_interactive_login( ) except OAuthError as e: # Raise a custom error (or handle it as appropriate) - raise ApiError("token_error", f"Token exchange failed: {str(e)}", e) + raise ApiError( + "token_error", f"Token exchange failed: {str(e)}", e) # Use the userinfo field from the token_response for user claims user_info = token_response.get("userinfo") @@ -655,14 +651,14 @@ async def complete_interactive_login( # Decode and verify ID token with signature verification enabled try: - claims = await self._verify_and_decode_jwt(id_token, jwks, audience=self._client_id) + claims = await self._verify_and_decode_jwt( + id_token, jwks, audience=self._client_id + ) # Custom normalized issuer validation token_issuer = claims.get("iss", "") if self._normalize_url(token_issuer) != self._normalize_url(origin_issuer): - raise IssuerValidationError( - "ID token issuer mismatch. Ensure your Auth0 domain is configured correctly." - ) + raise IssuerValidationError("ID token issuer mismatch. Ensure your Auth0 domain is configured correctly.") user_claims = UserClaims.parse_obj(claims) except ValueError as e: @@ -671,33 +667,40 @@ async def complete_interactive_login( raise ApiError( "invalid_signature", f"ID token signature verification failed. The token may have been tampered with or is from an untrusted source: {str(e)}", - e, + e ) except jwt.InvalidAudienceError as e: raise ApiError( "invalid_audience", f"ID token audience mismatch. Expected: {self._client_id}. Ensure your client_id is configured correctly: {str(e)}", - e, + e ) except jwt.ExpiredSignatureError as e: - raise ApiError("token_expired", f"ID token has expired: {str(e)}", e) + raise ApiError( + "token_expired", + f"ID token has expired: {str(e)}", + e + ) except jwt.InvalidTokenError as e: - raise ApiError("invalid_token", f"ID token verification failed: {str(e)}", e) + raise ApiError( + "invalid_token", + f"ID token verification failed: {str(e)}", + e + ) + # Build a token set using the token response data token_set = TokenSet( audience=transaction_data.audience or self.DEFAULT_AUDIENCE_STATE_KEY, access_token=token_response.get("access_token", ""), scope=token_response.get("scope", ""), - expires_at=int(time.time()) + token_response.get("expires_in", 3600), + expires_at=int(time.time()) + + token_response.get("expires_in", 3600) ) # Generate a session id (sid) from token_response or transaction data, or create a new one - sid = ( - user_info.get("sid") - if user_info and "sid" in user_info - else PKCE.generate_random_string(32) - ) + sid = user_info.get( + "sid") if user_info and "sid" in user_info else PKCE.generate_random_string(32) # Construct state data to represent the session state_data = StateData( @@ -707,7 +710,10 @@ async def complete_interactive_login( refresh_token=token_response.get("refresh_token"), token_sets=[token_set], domain=origin_domain, - internal={"sid": sid, "created_at": int(time.time())}, + internal={ + "sid": sid, + "created_at": int(time.time()) + } ) # Store the state data in the state store using store_options (Response required) @@ -732,9 +738,7 @@ async def complete_interactive_login( # Methods for retrieving user information, session data, and logout operations. # ============================================================================ - async def get_user( - self, store_options: Optional[dict[str, Any]] = None - ) -> Optional[dict[str, Any]]: + async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Optional[dict[str, Any]]: """ Retrieves the user from the store, or None if no user found. @@ -763,9 +767,7 @@ async def get_user( return state_data.get("user") return None - async def get_session( - self, store_options: Optional[dict[str, Any]] = None - ) -> Optional[dict[str, Any]]: + async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> Optional[dict[str, Any]]: """ Retrieve the user session from the store, or None if no session found. @@ -791,14 +793,15 @@ async def get_session( if self._normalize_url(session_domain) != self._normalize_url(current_domain): return None - session_data = {k: v for k, v in state_data.items() if k != "internal"} + session_data = {k: v for k, v in state_data.items() + if k != "internal"} return session_data return None async def logout( self, options: Optional[LogoutOptions] = None, - store_options: Optional[dict[str, Any]] = None, + store_options: Optional[dict[str, Any]] = None ) -> str: options = options or LogoutOptions() @@ -814,18 +817,19 @@ async def logout( if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() session_domain = self._get_session_domain(state_data) - if session_domain and self._normalize_url(session_domain) == self._normalize_url( - domain - ): + if session_domain and self._normalize_url(session_domain) == self._normalize_url(domain): await self._state_store.delete(self._state_identifier, store_options) # Return logout URL for the current resolved domain - logout_url = URL.create_logout_url(domain, self._client_id, options.return_to) + logout_url = URL.create_logout_url( + domain, self._client_id, options.return_to) return logout_url async def handle_backchannel_logout( - self, logout_token: str, store_options: Optional[dict[str, Any]] = None + self, + logout_token: str, + store_options: Optional[dict[str, Any]] = None ) -> None: """ Handles backchannel logout requests. @@ -846,7 +850,8 @@ async def handle_backchannel_logout( # Read iss from unverified token for comparison try: unverified = jwt.decode( - logout_token, algorithms=["RS256"], options={"verify_signature": False} + logout_token, algorithms=["RS256"], + options={"verify_signature": False} ) token_issuer = unverified.get("iss", "") except Exception as e: @@ -875,16 +880,13 @@ async def handle_backchannel_logout( jwks = await self._get_jwks_cached(domain) try: - claims = await self._verify_and_decode_jwt( - logout_token, jwks, audience=self._client_id - ) + claims = await self._verify_and_decode_jwt(logout_token, jwks, audience=self._client_id) # Normalized issuer validation token_issuer = claims.get("iss", "") expected_issuer = self._normalize_url(domain) if self._normalize_url(token_issuer) != self._normalize_url(expected_issuer): - raise IssuerValidationError( - "Logout token issuer mismatch.Ensure your Auth0 domain is configured correctly." + raise IssuerValidationError("Logout token issuer mismatch.Ensure your Auth0 domain is configured correctly." ) except ValueError as e: raise BackchannelLogoutError(str(e)) @@ -893,22 +895,30 @@ async def handle_backchannel_logout( f"Logout token signature verification failed: {str(e)}" ) except jwt.InvalidTokenError as e: - raise BackchannelLogoutError(f"Logout token verification failed: {str(e)}") + raise BackchannelLogoutError( + f"Logout token verification failed: {str(e)}" + ) # Validate the token is a logout token events = claims.get("events", {}) if "http://schemas.openid.net/event/backchannel-logout" not in events: - raise BackchannelLogoutError("Invalid logout token: not a backchannel logout event") + raise BackchannelLogoutError( + "Invalid logout token: not a backchannel logout event") # Delete sessions associated with this token logout_claims = LogoutTokenClaims( - sub=claims.get("sub"), sid=claims.get("sid"), iss=claims.get("iss") + sub=claims.get("sub"), + sid=claims.get("sid"), + iss=claims.get("iss") ) - await self._state_store.delete_by_logout_token(logout_claims.dict(), store_options) + await self._state_store.delete_by_logout_token( + logout_claims.dict(), store_options + ) except (jwt.PyJWTError, ValidationError) as e: - raise BackchannelLogoutError(f"Error processing logout token: {str(e)}") + raise BackchannelLogoutError( + f"Error processing logout token: {str(e)}") # ============================================================================ # ACCESS TOKEN MANAGEMENT @@ -949,13 +959,13 @@ async def get_access_token( if not session_domain: raise AccessTokenError( AccessTokenErrorCode.MISSING_SESSION_DOMAIN, - "Session domain does not match the current domain.", + "Session domain does not match the current domain." ) current_domain = await self._resolve_current_domain(store_options) if self._normalize_url(session_domain) != self._normalize_url(current_domain): raise AccessTokenError( AccessTokenErrorCode.DOMAIN_MISMATCH, - "Session domain does not match the current domain.", + "Session domain does not match the current domain." ) auth_params = self._default_authorization_params or {} @@ -969,9 +979,7 @@ async def get_access_token( # Find matching token set token_set = None if state_data_dict and "token_sets" in state_data_dict: - token_set = self._find_matching_token_set( - state_data_dict["token_sets"], audience, merged_scope - ) + token_set = self._find_matching_token_set(state_data_dict["token_sets"], audience, merged_scope) # If token is valid, return it if token_set and token_set.get("expires_at", 0) > time.time(): @@ -981,7 +989,7 @@ async def get_access_token( if not state_data_dict or not state_data_dict.get("refresh_token"): raise AccessTokenError( AccessTokenErrorCode.MISSING_REFRESH_TOKEN, - "The access token has expired and a refresh token was not provided. The user needs to re-authenticate.", + "The access token has expired and a refresh token was not provided. The user needs to re-authenticate." ) # Get new token with refresh token @@ -990,7 +998,7 @@ async def get_access_token( session_domain = state_data_dict.get("domain") or self._domain get_refresh_token_options = { "refresh_token": state_data_dict["refresh_token"], - "domain": session_domain, + "domain": session_domain } if audience: get_refresh_token_options["audience"] = audience @@ -998,20 +1006,15 @@ async def get_access_token( if merged_scope: get_refresh_token_options["scope"] = merged_scope - token_endpoint_response = await self.get_token_by_refresh_token( - get_refresh_token_options - ) + token_endpoint_response = await self.get_token_by_refresh_token(get_refresh_token_options) # Update state data with new token existing_state_data = await self._state_store.get(self._state_identifier, store_options) updated_state_data = State.update_state_data( - audience, existing_state_data, token_endpoint_response - ) + audience, existing_state_data, token_endpoint_response) # Store updated state - await self._state_store.set( - self._state_identifier, updated_state_data, options=store_options - ) + await self._state_store.set(self._state_identifier, updated_state_data, options=store_options) return token_endpoint_response["access_token"] except Exception as e: @@ -1025,21 +1028,22 @@ async def get_access_token( raw_mfa_token=raw_mfa_token, audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, scope=merged_scope or "", - mfa_requirements=mfa_requirements, + mfa_requirements=mfa_requirements ) raise MfaRequiredError( "Multifactor authentication required", mfa_token=encrypted_token, - mfa_requirements=mfa_requirements, + mfa_requirements=mfa_requirements ) if isinstance(e, AccessTokenError): raise raise AccessTokenError( AccessTokenErrorCode.REFRESH_TOKEN_ERROR, - f"Failed to get token with refresh token: {str(e)}", + f"Failed to get token with refresh token: {str(e)}" ) + async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, Any]: """ Retrieves a token by exchanging a refresh token. @@ -1067,7 +1071,8 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") + raise ApiError("configuration_error", + "Token endpoint missing in OIDC metadata") # Prepare the token request parameters token_params = { @@ -1082,7 +1087,8 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Merge scope if present in options with any in the original authorization params merged_scope = self._merge_scope_with_defaults( - request_scope=options.get("scope"), audience=audience + request_scope=options.get("scope"), + audience=audience ) if merged_scope: @@ -1091,7 +1097,9 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Exchange the refresh token for an access token async with self._get_http_client() as client: response = await client.post( - token_endpoint, data=token_params, auth=(self._client_id, self._client_secret) + token_endpoint, + data=token_params, + auth=(self._client_id, self._client_secret) ) if response.status_code != 200: @@ -1101,7 +1109,8 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Preserve mfa_required details for upstream handling if error_code == "mfa_required": error = ApiError( - error_code, error_data.get("error_description", "MFA required") + error_code, + error_data.get("error_description", "MFA required") ) error.mfa_token = error_data.get("mfa_token") mfa_requirements_data = error_data.get("mfa_requirements") @@ -1112,14 +1121,16 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, raise ApiError( error_code, - error_data.get("error_description", "Failed to exchange refresh token"), + error_data.get("error_description", + "Failed to exchange refresh token") ) token_response = response.json() # Add required fields if they are missing if "expires_in" in token_response and "expires_at" not in token_response: - token_response["expires_at"] = int(time.time()) + token_response["expires_in"] + token_response["expires_at"] = int( + time.time()) + token_response["expires_in"] return token_response @@ -1129,11 +1140,13 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, raise AccessTokenError( AccessTokenErrorCode.REFRESH_TOKEN_ERROR, "The access token has expired and there was an error while trying to refresh it.", - e, + e ) def _merge_scope_with_defaults( - self, request_scope: Optional[str], audience: Optional[str] + self, + request_scope: Optional[str], + audience: Optional[str] ) -> Optional[str]: """Helper: Merges requested scopes with default authorization params.""" audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY @@ -1154,7 +1167,10 @@ def _merge_scope_with_defaults( return " ".join(merged_scopes) if merged_scopes else None def _find_matching_token_set( - self, token_sets: list[dict[str, Any]], audience: Optional[str], scope: Optional[str] + self, + token_sets: list[dict[str, Any]], + audience: Optional[str], + scope: Optional[str] ) -> Optional[dict[str, Any]]: """Helper: Finds a token set matching the requested audience and scopes.""" audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY @@ -1180,7 +1196,9 @@ def _find_matching_token_set( # ============================================================================ async def login_backchannel( - self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None + self, + options: dict[str, Any], + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Logs in using Client-Initiated Backchannel Authentication. @@ -1199,22 +1217,22 @@ async def login_backchannel( Returns: A dictionary containing the authorizationDetails (when RAR was used). """ - token_endpoint_response = await self.backchannel_authentication( - { - "binding_message": options.get("binding_message"), - "login_hint": options.get("login_hint"), - "authorization_params": options.get("authorization_params"), - }, - store_options=store_options, - ) + token_endpoint_response = await self.backchannel_authentication({ + "binding_message": options.get("binding_message"), + "login_hint": options.get("login_hint"), + "authorization_params": options.get("authorization_params"), + }, store_options=store_options) existing_state_data = await self._state_store.get(self._state_identifier, store_options) audience = self._default_authorization_params.get( - "audience", self.DEFAULT_AUDIENCE_STATE_KEY - ) + "audience", self.DEFAULT_AUDIENCE_STATE_KEY) - state_data = State.update_state_data(audience, existing_state_data, token_endpoint_response) + state_data = State.update_state_data( + audience, + existing_state_data, + token_endpoint_response + ) # Store domain for MCD session domain = await self._resolve_current_domain(store_options) @@ -1222,11 +1240,15 @@ async def login_backchannel( await self._state_store.set(self._state_identifier, state_data, store_options) - result = {"authorization_details": token_endpoint_response.get("authorization_details")} + result = { + "authorization_details": token_endpoint_response.get("authorization_details") + } return result async def backchannel_authentication( - self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None + self, + options: dict[str, Any], + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Performs backchannel authentication with Auth0. @@ -1251,12 +1273,12 @@ async def backchannel_authentication( Raises: ApiError: If the backchannel authentication fails """ - backchannel_data = await self.initiate_backchannel_authentication( - options, store_options=store_options - ) + backchannel_data = await self.initiate_backchannel_authentication(options, store_options=store_options) auth_req_id = backchannel_data.get("auth_req_id") - expires_in = backchannel_data.get("expires_in", 120) # Default to 2 minutes - interval = backchannel_data.get("interval", 5) # Default to 5 seconds + expires_in = backchannel_data.get( + "expires_in", 120) # Default to 2 minutes + interval = backchannel_data.get( + "interval", 5) # Default to 5 seconds # Calculate when to stop polling end_time = time.time() + expires_in @@ -1265,9 +1287,7 @@ async def backchannel_authentication( while time.time() < end_time: # Make token request try: - token_response = await self.backchannel_authentication_grant( - auth_req_id, store_options=store_options - ) + token_response = await self.backchannel_authentication_grant(auth_req_id, store_options=store_options) return token_response except Exception as e: @@ -1283,14 +1303,17 @@ async def backchannel_authentication( raise ApiError( "backchannel_error", f"Backchannel authentication failed: {str(e) or 'Unknown error'}", - e, + e ) # If we get here, we've timed out - raise ApiError("timeout", "Backchannel authentication timed out") + raise ApiError( + "timeout", "Backchannel authentication timed out") async def initiate_backchannel_authentication( - self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None + self, + options: dict[str, Any], + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Start backchannel authentication with Auth0. @@ -1320,13 +1343,18 @@ async def initiate_backchannel_authentication( https://auth0.com/docs/get-started/authentication-and-authorization-flow/client-initiated-backchannel-authentication-flow """ - sub = options.get("login_hint", {}).get("sub") + sub = options.get('login_hint', {}).get("sub") if not sub: - raise MissingRequiredArgumentError("login_hint.sub") + raise MissingRequiredArgumentError( + "login_hint.sub" + ) - authorization_params = options.get("authorization_params") + authorization_params = options.get('authorization_params') if authorization_params is not None and not isinstance(authorization_params, dict): - raise ApiError("invalid_argument", "authorization_params must be a dict") + raise ApiError( + "invalid_argument", + "authorization_params must be a dict" + ) if authorization_params: requested_expiry = authorization_params.get("requested_expiry") @@ -1334,7 +1362,7 @@ async def initiate_backchannel_authentication( if not isinstance(requested_expiry, int) or requested_expiry <= 0: raise ApiError( "invalid_argument", - "authorization_params.requested_expiry must be a positive integer", + "authorization_params.requested_expiry must be a positive integer" ) try: @@ -1343,18 +1371,24 @@ async def initiate_backchannel_authentication( metadata = await self._get_oidc_metadata_cached(domain) # Get the issuer from metadata - issuer = metadata.get("issuer") or f"https://{domain}/" + issuer = metadata.get( + "issuer") or f"https://{domain}/" # Get backchannel authentication endpoint - backchannel_endpoint = metadata.get("backchannel_authentication_endpoint") + backchannel_endpoint = metadata.get( + "backchannel_authentication_endpoint") if not backchannel_endpoint: raise ApiError( "configuration_error", - "Backchannel authentication is not supported by the authorization server", + "Backchannel authentication is not supported by the authorization server" ) # Prepare login hint in the required format - login_hint = json.dumps({"format": "iss_sub", "iss": issuer, "sub": sub}) + login_hint = json.dumps({ + "format": "iss_sub", + "iss": issuer, + "sub": sub + }) # The Request Parameters params = { @@ -1364,8 +1398,8 @@ async def initiate_backchannel_authentication( } # Add binding message if provided - if options.get("binding_message"): - params["binding_message"] = options.get("binding_message") + if options.get('binding_message'): + params["binding_message"] = options.get('binding_message') # Add any additional authorization parameters if self._default_authorization_params: @@ -1377,7 +1411,9 @@ async def initiate_backchannel_authentication( # Make the backchannel authentication request async with self._get_http_client() as client: backchannel_response = await client.post( - backchannel_endpoint, data=params, auth=(self._client_id, self._client_secret) + backchannel_endpoint, + data=params, + auth=(self._client_id, self._client_secret) ) if backchannel_response.status_code != 200: @@ -1385,8 +1421,7 @@ async def initiate_backchannel_authentication( raise ApiError( error_data.get("error", "backchannel_error"), error_data.get( - "error_description", "Backchannel authentication request failed" - ), + "error_description", "Backchannel authentication request failed") ) backchannel_data = backchannel_response.json() @@ -1395,7 +1430,7 @@ async def initiate_backchannel_authentication( if not auth_req_id: raise ApiError( "invalid_response", - "Missing auth_req_id in backchannel authentication response", + "Missing auth_req_id in backchannel authentication response" ) return backchannel_data @@ -1406,11 +1441,13 @@ async def initiate_backchannel_authentication( raise ApiError( "backchannel_error", f"Backchannel authentication failed: {str(e) or 'Unknown error'}", - e, + e ) async def backchannel_authentication_grant( - self, auth_req_id: str, store_options: Optional[dict[str, Any]] = None + self, + auth_req_id: str, + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Retrieves a token by exchanging an auth_req_id. @@ -1435,20 +1472,23 @@ async def backchannel_authentication_grant( token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") + raise ApiError("configuration_error", + "Token endpoint missing in OIDC metadata") # Prepare the token request parameters token_params = { "grant_type": "urn:openid:params:grant-type:ciba", "auth_req_id": auth_req_id, "client_id": self._client_id, - "client_secret": self._client_secret, + "client_secret": self._client_secret } # Exchange the auth_req_id for an access token async with self._get_http_client() as client: response = await client.post( - token_endpoint, data=token_params, auth=(self._client_id, self._client_secret) + token_endpoint, + data=token_params, + auth=(self._client_id, self._client_secret) ) if response.status_code != 200: @@ -1457,18 +1497,23 @@ async def backchannel_authentication_grant( interval = int(retry_after) if retry_after is not None else None raise PollingApiError( error_data.get("error", "auth_req_id_error"), - error_data.get("error_description", "Failed to exchange auth_req_id"), - interval, + error_data.get("error_description", + "Failed to exchange auth_req_id"), + interval ) try: token_response = response.json() except json.JSONDecodeError: - raise ApiError("invalid_response", "Failed to parse token response as JSON") + raise ApiError( + "invalid_response", + "Failed to parse token response as JSON" + ) # Add required fields if they are missing if "expires_in" in token_response and "expires_at" not in token_response: - token_response["expires_at"] = int(time.time()) + token_response["expires_in"] + token_response["expires_at"] = int( + time.time()) + token_response["expires_in"] return token_response @@ -1478,7 +1523,7 @@ async def backchannel_authentication_grant( raise AccessTokenError( AccessTokenErrorCode.AUTH_REQ_ID_ERROR, "There was an error while trying to exchange the auth_req_id for an access token.", - e, + e ) # ============================================================================ @@ -1487,7 +1532,11 @@ async def backchannel_authentication_grant( # to a user's Auth0 profile. # ============================================================================ - async def start_link_user(self, options, store_options: Optional[dict[str, Any]] = None): + async def start_link_user( + self, + options, + store_options: Optional[dict[str, Any]] = None + ): """ Starts the user linking process, and returns a URL to redirect the user-agent to. @@ -1514,9 +1563,13 @@ async def start_link_user(self, options, store_options: Optional[dict[str, Any]] state_data = state_data.dict() session_domain = self._get_session_domain(state_data) if not session_domain: - raise StartLinkUserError("Session domain does not match the current domain.") + raise StartLinkUserError( + "Session domain does not match the current domain." + ) if self._normalize_url(session_domain) != self._normalize_url(origin_domain): - raise StartLinkUserError("Session domain does not match the current domain.") + raise StartLinkUserError( + "Session domain does not match the current domain." + ) # Generate PKCE and state for security code_verifier = PKCE.generate_code_verifier() @@ -1530,7 +1583,7 @@ async def start_link_user(self, options, store_options: Optional[dict[str, Any]] code_verifier=code_verifier, state=state, authorization_params=options.get("authorization_params"), - domain=origin_domain, + domain=origin_domain ) # Store transaction data @@ -1541,13 +1594,17 @@ async def start_link_user(self, options, store_options: Optional[dict[str, Any]] ) await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", transaction_data, options=store_options + f"{self._transaction_identifier}:{state}", + transaction_data, + options=store_options ) return link_user_url async def complete_link_user( - self, url: str, store_options: Optional[dict[str, Any]] = None + self, + url: str, + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Completes the user linking process. @@ -1564,9 +1621,15 @@ async def complete_link_user( result = await self.complete_interactive_login(url, store_options) # Return just the app state as specified - return {"app_state": result.get("app_state")} + return { + "app_state": result.get("app_state") + } - async def start_unlink_user(self, options, store_options: Optional[dict[str, Any]] = None): + async def start_unlink_user( + self, + options, + store_options: Optional[dict[str, Any]] = None + ): """ Starts the user unlinking process, and returns a URL to redirect the user-agent to. @@ -1593,9 +1656,13 @@ async def start_unlink_user(self, options, store_options: Optional[dict[str, Any state_data = state_data.dict() session_domain = self._get_session_domain(state_data) if not session_domain: - raise StartLinkUserError("Session domain does not match the current domain.") + raise StartLinkUserError( + "Session domain does not match the current domain." + ) if self._normalize_url(session_domain) != self._normalize_url(origin_domain): - raise StartLinkUserError("Session domain does not match the current domain.") + raise StartLinkUserError( + "Session domain does not match the current domain." + ) # Generate PKCE and state for security code_verifier = PKCE.generate_code_verifier() @@ -1608,7 +1675,7 @@ async def start_unlink_user(self, options, store_options: Optional[dict[str, Any code_verifier=code_verifier, state=state, authorization_params=options.get("authorization_params"), - domain=origin_domain, + domain=origin_domain ) # Store transaction data @@ -1619,13 +1686,17 @@ async def start_unlink_user(self, options, store_options: Optional[dict[str, Any ) await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", transaction_data, options=store_options + f"{self._transaction_identifier}:{state}", + transaction_data, + options=store_options ) return link_user_url async def complete_unlink_user( - self, url: str, store_options: Optional[dict[str, Any]] = None + self, + url: str, + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Completes the user unlinking process. @@ -1642,7 +1713,9 @@ async def complete_unlink_user( result = await self.complete_interactive_login(url, store_options) # Return just the app state as specified - return {"app_state": result.get("app_state")} + return { + "app_state": result.get("app_state") + } async def _build_link_user_url( self, @@ -1652,7 +1725,7 @@ async def _build_link_user_url( state: str, connection_scope: Optional[str] = None, authorization_params: Optional[dict[str, Any]] = None, - domain: Optional[str] = None, + domain: Optional[str] = None ) -> str: """Build a URL for linking user accounts""" # Generate code challenge from verifier @@ -1663,9 +1736,8 @@ async def _build_link_user_url( metadata = await self._get_oidc_metadata_cached(resolved_domain) # Get authorization endpoint - auth_endpoint = metadata.get( - "authorization_endpoint", f"https://{resolved_domain}/authorize" - ) + auth_endpoint = metadata.get("authorization_endpoint", + f"https://{resolved_domain}/authorize") # Build params params = { @@ -1678,7 +1750,7 @@ async def _build_link_user_url( "response_type": "code", "id_token_hint": id_token, "scope": "openid link_account", - "audience": "my-account", + "audience": "my-account" } # Add connection scope if provided @@ -1697,7 +1769,7 @@ async def _build_unlink_user_url( code_verifier: str, state: str, authorization_params: Optional[dict[str, Any]] = None, - domain: Optional[str] = None, + domain: Optional[str] = None ) -> str: """Build a URL for unlinking user accounts""" # Generate code challenge from verifier @@ -1708,9 +1780,8 @@ async def _build_unlink_user_url( metadata = await self._get_oidc_metadata_cached(resolved_domain) # Get authorization endpoint - auth_endpoint = metadata.get( - "authorization_endpoint", f"https://{resolved_domain}/authorize" - ) + auth_endpoint = metadata.get("authorization_endpoint", + f"https://{resolved_domain}/authorize") # Build params params = { @@ -1722,7 +1793,7 @@ async def _build_unlink_user_url( "response_type": "code", "id_token_hint": id_token, "scope": "openid unlink_account", - "audience": "my-account", + "audience": "my-account" } # Add any additional parameters if authorization_params: @@ -1737,7 +1808,9 @@ async def _build_unlink_user_url( # ============================================================================ async def get_access_token_for_connection( - self, options: dict[str, Any], store_options: Optional[dict[str, Any]] = None + self, + options: dict[str, Any], + store_options: Optional[dict[str, Any]] = None ) -> str: """ Retrieves an access token for a connection. @@ -1771,13 +1844,13 @@ async def get_access_token_for_connection( if not session_domain: raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.MISSING_SESSION_DOMAIN, - "Session domain does not match the current domain.", + "Session domain does not match the current domain." ) current_domain = await self._resolve_current_domain(store_options) if self._normalize_url(session_domain) != self._normalize_url(current_domain): raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.DOMAIN_MISMATCH, - "Session domain does not match the current domain.", + "Session domain does not match the current domain." ) # Find existing connection token @@ -1796,24 +1869,21 @@ async def get_access_token_for_connection( if not state_data_dict or not state_data_dict.get("refresh_token"): raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.MISSING_REFRESH_TOKEN, - "A refresh token was not found but is required to be able to retrieve an access token for a connection.", + "A refresh token was not found but is required to be able to retrieve an access token for a connection." ) # Get new token for connection # Use session's domain for token exchange session_domain = state_data_dict.get("domain") or self._domain - token_endpoint_response = await self.get_token_for_connection( - { - "connection": options.get("connection"), - "login_hint": options.get("login_hint"), - "refresh_token": state_data_dict["refresh_token"], - "domain": session_domain, - } - ) + token_endpoint_response = await self.get_token_for_connection({ + "connection": options.get("connection"), + "login_hint": options.get("login_hint"), + "refresh_token": state_data_dict["refresh_token"], + "domain": session_domain + }) # Update state data with new token updated_state_data = State.update_state_data_for_connection_token_set( - options, state_data_dict, token_endpoint_response - ) + options, state_data_dict, token_endpoint_response) # Store updated state await self._state_store.set(self._state_identifier, updated_state_data, store_options) @@ -1837,12 +1907,8 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A """ # Constants SUBJECT_TYPE_REFRESH_TOKEN = "urn:ietf:params:oauth:token-type:refresh_token" - REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = ( - "http://auth0.com/oauth/token-type/federated-connection-access-token" - ) - GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = ( - "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" - ) + REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "http://auth0.com/oauth/token-type/federated-connection-access-token" + GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" try: # Use session domain if provided, otherwise fallback to static domain domain = options.get("domain") or self._domain @@ -1852,7 +1918,8 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") + raise ApiError("configuration_error", + "Token endpoint missing in OIDC metadata") # Prepare parameters params = { @@ -1861,7 +1928,7 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A "subject_token": options["refresh_token"], "requested_token_type": REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN, "grant_type": GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN, - "client_id": self._client_id, + "client_id": self._client_id } # Add login_hint if provided @@ -1871,41 +1938,38 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A # Make the request async with self._get_http_client() as client: response = await client.post( - token_endpoint, data=params, auth=(self._client_id, self._client_secret) + token_endpoint, + data=params, + auth=(self._client_id, self._client_secret) ) if response.status_code != 200: - error_data = ( - response.json() - if response.headers.get("content-type") == "application/json" - else {} - ) + error_data = response.json() if response.headers.get( + "content-type") == "application/json" else {} raise ApiError( error_data.get("error", "connection_token_error"), error_data.get( - "error_description", - f"Failed to get token for connection: {response.status_code}", - ), + "error_description", f"Failed to get token for connection: {response.status_code}") ) token_endpoint_response = response.json() return { "access_token": token_endpoint_response.get("access_token"), - "expires_at": int(time.time()) - + int(token_endpoint_response.get("expires_in", 3600)), - "scope": token_endpoint_response.get("scope", ""), + "expires_at": int(time.time()) + int(token_endpoint_response.get("expires_in", 3600)), + "scope": token_endpoint_response.get("scope", "") } except Exception as e: if isinstance(e, ApiError): raise AccessTokenForConnectionError( - AccessTokenForConnectionErrorCode.API_ERROR, str(e) + AccessTokenForConnectionErrorCode.API_ERROR, + str(e) ) raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.FETCH_ERROR, "There was an error while trying to retrieve an access token for a connection.", - e, + e ) # ============================================================================ @@ -1915,7 +1979,9 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A # ============================================================================ async def start_connect_account( - self, options: ConnectAccountOptions, store_options: dict = None + self, + options: ConnectAccountOptions, + store_options: dict = None ) -> str: """ Initiates the connect account flow for linking a third-party account to the user's profile. @@ -1940,25 +2006,26 @@ async def start_connect_account( code_verifier = PKCE.generate_code_verifier() code_challenge = PKCE.generate_code_challenge(code_verifier) - state = PKCE.generate_random_string(32) + state= PKCE.generate_random_string(32) connect_request = ConnectAccountRequest( connection=options.connection, scopes=options.scopes, - redirect_uri=redirect_uri, + redirect_uri = redirect_uri, code_challenge=code_challenge, code_challenge_method="S256", state=state, - authorization_params=options.authorization_params, + authorization_params=options.authorization_params ) access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="create:me:connected_accounts", - store_options=store_options, + store_options=store_options ) connect_response = await self._my_account_client.connect_account( - access_token=access_token, request=connect_request + access_token=access_token, + request=connect_request ) # Build the transaction data to store @@ -1966,29 +2033,24 @@ async def start_connect_account( code_verifier=code_verifier, app_state=options.app_state, auth_session=connect_response.auth_session, - redirect_uri=redirect_uri, + redirect_uri=redirect_uri ) # Store the transaction data await self._transaction_store.set( - f"{self._transaction_identifier}:{state}", transaction_data, options=store_options + f"{self._transaction_identifier}:{state}", + transaction_data, + options=store_options ) parsedUrl = urlparse(connect_response.connect_uri) query = urlencode({"ticket": connect_response.connect_params.ticket}) - return urlunparse( - ( - parsedUrl.scheme, - parsedUrl.netloc, - parsedUrl.path, - parsedUrl.params, - query, - parsedUrl.fragment, - ) - ) + return urlunparse((parsedUrl.scheme, parsedUrl.netloc, parsedUrl.path, parsedUrl.params, query, parsedUrl.fragment)) async def complete_connect_account( - self, url: str, store_options: dict = None + self, + url: str, + store_options: dict = None ) -> CompleteConnectAccountResponse: """ Handles the redirect callback to complete the connect account flow for linking a third-party @@ -2020,9 +2082,7 @@ async def complete_connect_account( # Retrieve the transaction data using the state transaction_identifier = f"{self._transaction_identifier}:{state}" - transaction_data = await self._transaction_store.get( - transaction_identifier, options=store_options - ) + transaction_data = await self._transaction_store.get(transaction_identifier, options=store_options) if not transaction_data: raise MissingTransactionError() @@ -2030,19 +2090,18 @@ async def complete_connect_account( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="create:me:connected_accounts", - store_options=store_options, + store_options=store_options ) request = CompleteConnectAccountRequest( auth_session=transaction_data.auth_session, connect_code=connect_code, redirect_uri=transaction_data.redirect_uri, - code_verifier=transaction_data.code_verifier, + code_verifier=transaction_data.code_verifier ) try: response = await self._my_account_client.complete_connect_account( - access_token=access_token, request=request - ) + access_token=access_token, request=request) if transaction_data.app_state is not None: response.app_state = transaction_data.app_state finally: @@ -2056,7 +2115,7 @@ async def list_connected_accounts( connection: Optional[str] = None, from_param: Optional[str] = None, take: Optional[int] = None, - store_options: dict = None, + store_options: dict = None ) -> ListConnectedAccountsResponse: """ Retrieves a list of connected accounts for the authenticated user. @@ -2080,14 +2139,15 @@ async def list_connected_accounts( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="read:me:connected_accounts", - store_options=store_options, + store_options=store_options ) return await self._my_account_client.list_connected_accounts( - access_token=access_token, connection=connection, from_param=from_param, take=take - ) + access_token=access_token, connection=connection, from_param=from_param, take=take) async def delete_connected_account( - self, connected_account_id: str, store_options: dict = None + self, + connected_account_id: str, + store_options: dict = None ) -> None: """ Deletes a connected account. @@ -2106,17 +2166,16 @@ async def delete_connected_account( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="delete:me:connected_accounts", - store_options=store_options, + store_options=store_options ) await self._my_account_client.delete_connected_account( - access_token=access_token, connected_account_id=connected_account_id - ) + access_token=access_token, connected_account_id=connected_account_id) async def list_connected_account_connections( self, from_param: Optional[str] = None, take: Optional[int] = None, - store_options: dict = None, + store_options: dict = None ) -> ListConnectedAccountConnectionsResponse: """ Retrieves a list of available connections that can be used connected accounts for the authenticated user. @@ -2139,11 +2198,10 @@ async def list_connected_account_connections( access_token = await self.get_access_token( audience=self._my_account_client.audience, scope="read:me:connected_accounts", - store_options=store_options, + store_options=store_options ) return await self._my_account_client.list_connected_account_connections( - access_token=access_token, from_param=from_param, take=take - ) + access_token=access_token, from_param=from_param, take=take) # ============================================================================ # CUSTOM TOKEN EXCHANGE (RFC 8693) @@ -2151,7 +2209,9 @@ async def list_connected_account_connections( # ============================================================================ async def custom_token_exchange( - self, options: CustomTokenExchangeOptions, store_options: Optional[dict[str, Any]] = None + self, + options: CustomTokenExchangeOptions, + store_options: Optional[dict[str, Any]] = None ) -> TokenExchangeResponse: """ Exchanges a custom token for Auth0 tokens using RFC 8693. @@ -2224,12 +2284,7 @@ async def custom_token_exchange( # Merge additional authorization params if options.authorization_params: # Prevent override of critical parameters - forbidden_params = { - "grant_type", - "client_id", - "subject_token", - "subject_token_type", - } + forbidden_params = {"grant_type", "client_id", "subject_token", "subject_token_type"} for key, value in options.authorization_params.items(): if key not in forbidden_params: params[key] = value @@ -2237,20 +2292,17 @@ async def custom_token_exchange( # Make the token exchange request async with self._get_http_client() as client: response = await client.post( - token_endpoint, data=params, auth=(self._client_id, self._client_secret) + token_endpoint, + data=params, + auth=(self._client_id, self._client_secret) ) if response.status_code != 200: - error_data = ( - response.json() - if response.headers.get("content-type", "").startswith("application/json") - else {} - ) + error_data = response.json() if response.headers.get( + "content-type", "").startswith("application/json") else {} raise CustomTokenExchangeError( error_data.get("error", CustomTokenExchangeErrorCode.TOKEN_EXCHANGE_FAILED), - error_data.get( - "error_description", f"Token exchange failed: {response.status_code}" - ), + error_data.get("error_description", f"Token exchange failed: {response.status_code}") ) try: @@ -2258,7 +2310,7 @@ async def custom_token_exchange( except json.JSONDecodeError: raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.INVALID_RESPONSE, - "Failed to parse token response as JSON", + "Failed to parse token response as JSON" ) # Validate and return response @@ -2267,7 +2319,7 @@ async def custom_token_exchange( except ValidationError as e: raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.INVALID_TOKEN_FORMAT, - f"Token validation failed: {str(e)}", + f"Token validation failed: {str(e)}" ) except Exception as e: if isinstance(e, (CustomTokenExchangeError, ApiError)): @@ -2275,13 +2327,13 @@ async def custom_token_exchange( raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.TOKEN_EXCHANGE_FAILED, f"Token exchange failed: {str(e)}", - e, + e ) async def login_with_custom_token_exchange( self, options: LoginWithCustomTokenExchangeOptions, - store_options: Optional[dict[str, Any]] = None, + store_options: Optional[dict[str, Any]] = None ) -> LoginWithCustomTokenExchangeResult: """ Performs token exchange and establishes a user session. @@ -2326,12 +2378,10 @@ async def login_with_custom_token_exchange( actor_token=options.actor_token, actor_token_type=options.actor_token_type, organization=options.organization, - authorization_params=options.authorization_params, + authorization_params=options.authorization_params ) - token_response = await self.custom_token_exchange( - exchange_options, store_options=store_options - ) + token_response = await self.custom_token_exchange(exchange_options, store_options=store_options) # Resolve domain and fetch metadata for verification domain = await self._resolve_current_domain(store_options) @@ -2363,18 +2413,28 @@ async def login_with_custom_token_exchange( raise ApiError("jwks_key_not_found", str(e)) except jwt.InvalidSignatureError as e: raise ApiError( - "invalid_signature", f"ID token signature verification failed: {str(e)}", e + "invalid_signature", + f"ID token signature verification failed: {str(e)}", + e ) except jwt.InvalidAudienceError as e: raise ApiError( "invalid_audience", f"ID token audience mismatch. Expected: {self._client_id}: {str(e)}", - e, + e ) except jwt.ExpiredSignatureError as e: - raise ApiError("token_expired", f"ID token has expired: {str(e)}", e) + raise ApiError( + "token_expired", + f"ID token has expired: {str(e)}", + e + ) except jwt.InvalidTokenError as e: - raise ApiError("invalid_token", f"ID token verification failed: {str(e)}", e) + raise ApiError( + "invalid_token", + f"ID token verification failed: {str(e)}", + e + ) # Determine audience for token set audience = options.audience or self.DEFAULT_AUDIENCE_STATE_KEY @@ -2384,7 +2444,7 @@ async def login_with_custom_token_exchange( audience=audience, access_token=token_response.access_token, scope=token_response.scope or options.scope or "", - expires_at=int(time.time()) + token_response.expires_in, + expires_at=int(time.time()) + token_response.expires_in ) # Construct state data @@ -2394,14 +2454,19 @@ async def login_with_custom_token_exchange( refresh_token=token_response.refresh_token, token_sets=[token_set], domain=domain, - internal={"sid": sid, "created_at": int(time.time())}, + internal={ + "sid": sid, + "created_at": int(time.time()) + } ) # Store session await self._state_store.set(self._state_identifier, state_data, options=store_options) # Build result - result = LoginWithCustomTokenExchangeResult(state_data=state_data.dict()) + result = LoginWithCustomTokenExchangeResult( + state_data=state_data.dict() + ) return result @@ -2411,11 +2476,11 @@ async def login_with_custom_token_exchange( raise CustomTokenExchangeError( CustomTokenExchangeErrorCode.TOKEN_EXCHANGE_FAILED, f"Login with custom token exchange failed: {str(e)}", - e, + e ) # ============================================================================ - # PASSKEY AUTHENTICATION (Category 1) + # PASSKEY AUTHENTICATION # ============================================================================ GRANT_TYPE_PASSKEY = "urn:okta:params:oauth:grant-type:webauthn" From 66071e9698e597d22d58ea636be96ddfdc2e8792 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Tue, 2 Jun 2026 17:42:43 +0530 Subject: [PATCH 6/9] Edge case fix for Double URL-encoding and extra validation check --- .../auth_schemes/dpop_auth.py | 4 ++++ .../auth_server/my_account_client.py | 20 +++++++++---------- .../auth_types/__init__.py | 6 +++--- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/auth0_server_python/auth_schemes/dpop_auth.py b/src/auth0_server_python/auth_schemes/dpop_auth.py index 1517a78..d10b8f3 100644 --- a/src/auth0_server_python/auth_schemes/dpop_auth.py +++ b/src/auth0_server_python/auth_schemes/dpop_auth.py @@ -17,6 +17,10 @@ def __init__(self, token: str, key: "jwk.JWK") -> None: public_jwk = key.export_public(as_dict=True) if public_jwk.get("kty") != "EC" or public_jwk.get("crv") != "P-256": raise ValueError("DPoP key must be an EC P-256 key") + try: + token.encode("ascii") + except UnicodeEncodeError: + raise ValueError("Access token must contain only ASCII characters") self._token = token self._key = key self._public_jwk = public_jwk diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index bd23a12..e5ef646 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -1,10 +1,8 @@ import json from typing import TYPE_CHECKING, Optional -from urllib.parse import quote +from urllib.parse import quote, unquote import httpx -from pydantic import ValidationError - from auth0_server_python.auth_schemes.bearer_auth import BearerAuth from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth from auth0_server_python.auth_types import ( @@ -416,7 +414,7 @@ async def get_factors( return GetFactorsResponse.model_validate(response.json()) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "get_factors_error", @@ -464,7 +462,7 @@ async def list_authentication_methods( return ListAuthenticationMethodsResponse.model_validate(response.json()) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "list_authentication_methods_error", @@ -509,7 +507,7 @@ async def get_authentication_method( return AuthenticationMethod.model_validate(response.json()) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "get_authentication_method_error", @@ -552,7 +550,7 @@ async def delete_authentication_method( ) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "delete_authentication_method_error", @@ -601,7 +599,7 @@ async def update_authentication_method( return AuthenticationMethod.model_validate(response.json()) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "update_authentication_method_error", @@ -661,7 +659,7 @@ async def enroll_authentication_method( path = location.split("?")[0].split("#")[0].rstrip("/") segments = path.split("/") - authentication_method_id = segments[-1] if len(segments) > 1 else "" + authentication_method_id = unquote(segments[-1]) if len(segments) > 1 else "" if not authentication_method_id or authentication_method_id in ( "authentication-methods", "v1", @@ -696,7 +694,7 @@ async def enroll_authentication_method( ) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "enroll_authentication_method_error", @@ -749,7 +747,7 @@ async def verify_authentication_method( return AuthenticationMethod.model_validate(response.json()) except Exception as e: - if isinstance(e, (MyAccountApiError, ApiError, ValidationError)): + if isinstance(e, (MyAccountApiError, ApiError)): raise raise ApiError( "verify_authentication_method_error", diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index d306efa..8141a18 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -553,9 +553,9 @@ class VerifyAuthenticationMethodRequest(BaseModel): def _check_at_least_one_method(self) -> "VerifyAuthenticationMethodRequest": has_method = ( self.authn_response is not None - or self.otp_code is not None - or self.recovery_code is not None - or self.password is not None + or (self.otp_code is not None and self.otp_code.strip() != "") + or (self.recovery_code is not None and self.recovery_code.strip() != "") + or (self.password is not None and self.password.strip() != "") ) if not has_method: raise ValueError( From 3809edb21b3c0f7df4cbb15fffc5c9a167340856 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Fri, 5 Jun 2026 12:29:33 +0530 Subject: [PATCH 7/9] SDK-8780 PR review changes --- .../auth_schemes/dpop_auth.py | 38 +- .../auth_server/my_account_client.py | 8 +- .../auth_server/server_client.py | 158 ++-- .../auth_types/__init__.py | 388 ++++---- src/auth0_server_python/error/__init__.py | 18 + .../tests/test_my_account_client.py | 835 ++++++++++++++++++ .../tests/test_passkey_my_account.py | 830 ----------------- .../tests/test_passkey_server_client.py | 585 ------------ .../tests/test_server_client.py | 794 +++++++++++++++++ 9 files changed, 1956 insertions(+), 1698 deletions(-) delete mode 100644 src/auth0_server_python/tests/test_passkey_my_account.py delete mode 100644 src/auth0_server_python/tests/test_passkey_server_client.py diff --git a/src/auth0_server_python/auth_schemes/dpop_auth.py b/src/auth0_server_python/auth_schemes/dpop_auth.py index d10b8f3..0bf2d66 100644 --- a/src/auth0_server_python/auth_schemes/dpop_auth.py +++ b/src/auth0_server_python/auth_schemes/dpop_auth.py @@ -12,6 +12,28 @@ def _base64url(data: bytes) -> str: return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") +def make_dpop_proof_for_token_endpoint(key: "jwk.JWK", method: str, url: str, nonce: str = None) -> str: + """ + Build a DPoP proof JWT for use at the token endpoint (RFC 9449 §4.2). + Unlike resource-server proofs, token-endpoint proofs do NOT include `ath` + because no access token exists yet at issuance time. + """ + public_jwk = key.export_public(as_dict=True) + htu = url.split("?")[0].split("#")[0] + header = {"typ": "dpop+jwt", "alg": "ES256", "jwk": public_jwk} + payload = { + "jti": str(uuid.uuid4()), + "htm": method.upper(), + "htu": htu, + "iat": int(time.time()), + } + if nonce is not None: + payload["nonce"] = nonce + token = jwcrypto_jwt.JWT(header=header, claims=payload) + token.make_signed_token(key) + return token.serialize() + + class DPoPAuth(httpx.Auth): def __init__(self, token: str, key: "jwk.JWK") -> None: public_jwk = key.export_public(as_dict=True) @@ -35,9 +57,19 @@ def auth_flow(self, request: httpx.Request): proof = self._make_proof(request.method, str(request.url)) request.headers["Authorization"] = f"DPoP {self._token}" request.headers["DPoP"] = proof - yield request + response = yield request + + # RFC 9449 §8.2 — server-nonce retry + if ( + response is not None + and response.status_code == 401 + and response.headers.get("DPoP-Nonce") + ): + nonce = response.headers["DPoP-Nonce"] + request.headers["DPoP"] = self._make_proof(request.method, str(request.url), nonce=nonce) + yield request - def _make_proof(self, method: str, url: str) -> str: + def _make_proof(self, method: str, url: str, nonce: str = None) -> str: htu = url.split("?")[0].split("#")[0] ath = _base64url(hashlib.sha256(self._token.encode("ascii")).digest()) @@ -49,6 +81,8 @@ def _make_proof(self, method: str, url: str) -> str: "iat": int(time.time()), "ath": ath, } + if nonce is not None: + payload["nonce"] = nonce token = jwcrypto_jwt.JWT(header=header, claims=payload) token.make_signed_token(self._key) diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index e5ef646..5ffadd9 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -619,7 +619,7 @@ async def enroll_authentication_method( navigator.credentials.create(), then call verify_authentication_method() with the auth_session and credential result. - Requires scope: create:me:authentication_methods + Requires scope: create:me:authentication-methods """ if not access_token: raise MissingRequiredArgumentError("access_token") @@ -634,7 +634,7 @@ async def enroll_authentication_method( auth=_make_auth(access_token, dpop_key), ) - if response.status_code != 201: + if response.status_code != 202: try: error_data = response.json() except (json.JSONDecodeError, ValueError): @@ -711,7 +711,7 @@ async def verify_authentication_method( ) -> AuthenticationMethod: """Step 2 of 2: Verify enrollment (POST /me/v1/authentication-methods/{id}/verify). - Requires scope: create:me:authentication_methods + Requires scope: create:me:authentication-methods """ if not access_token: raise MissingRequiredArgumentError("access_token") @@ -728,7 +728,7 @@ async def verify_authentication_method( auth=_make_auth(access_token, dpop_key), ) - if response.status_code != 200: + if response.status_code != 201: try: error_data = response.json() except (json.JSONDecodeError, ValueError): diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index a233205..110f33a 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -7,12 +7,16 @@ import json import time from collections import OrderedDict -from typing import Any, Callable, Generic, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union + +if TYPE_CHECKING: + from jwcrypto import jwk from urllib.parse import parse_qs, urlencode, urlparse, urlunparse import httpx import jwt from authlib.integrations.base_client.errors import OAuthError +from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth, make_dpop_proof_for_token_endpoint from authlib.integrations.httpx_client import AsyncOAuth2Client from pydantic import ValidationError @@ -34,6 +38,7 @@ PasskeyAuthResponse, PasskeyLoginChallengeResponse, PasskeySignupChallengeResponse, + PasskeyUserProfile, PasskeyTokenResponse, StartInteractiveLoginOptions, StateData, @@ -58,6 +63,8 @@ MfaRequiredError, MissingRequiredArgumentError, MissingTransactionError, + PasskeyError, + PasskeyErrorCode, PollingApiError, StartLinkUserError, ) @@ -82,6 +89,9 @@ class ServerClient(Generic[TStoreOptions]): and token operations using Authlib for OIDC functionality. """ DEFAULT_AUDIENCE_STATE_KEY = "default" + GRANT_TYPE_PASSKEY = "urn:okta:params:oauth:grant-type:webauthn" + PASSKEY_REGISTER_PATH = "/passkey/register" + PASSKEY_CHALLENGE_PATH = "/passkey/challenge" # ============================================================================ # INITIALIZATION @@ -2480,22 +2490,21 @@ async def login_with_custom_token_exchange( ) # ============================================================================ - # PASSKEY AUTHENTICATION + # MFA (Multi-Factor Authentication) # ============================================================================ - GRANT_TYPE_PASSKEY = "urn:okta:params:oauth:grant-type:webauthn" + @property + def mfa(self) -> MfaClient: + """Access the MFA client for multi-factor authentication operations.""" + return self._mfa_client + + # ============================================================================ + # PASSKEY AUTHENTICATION + # ============================================================================ async def passkey_signup_challenge( self, - name: Optional[str] = None, - email: Optional[str] = None, - username: Optional[str] = None, - phone_number: Optional[str] = None, - given_name: Optional[str] = None, - family_name: Optional[str] = None, - nickname: Optional[str] = None, - picture: Optional[str] = None, - user_metadata: Optional[dict[str, Any]] = None, + user_profile: Optional[PasskeyUserProfile] = None, connection: Optional[str] = None, organization: Optional[str] = None, store_options: Optional[dict[str, Any]] = None, @@ -2507,15 +2516,8 @@ async def passkey_signup_challenge( then call signin_with_passkey() with the auth_session and credential result. Args: - name: User's full name. - email: User's email address. - username: Username for the new account. - phone_number: User's phone number. - given_name: User's given (first) name. - family_name: User's family (last) name. - nickname: User's nickname. - picture: URL to the user's profile picture. - user_metadata: Arbitrary user metadata dict. + user_profile: Optional user profile data (email, name, username, etc.). + Use PasskeyUserProfile — supports extra fields for forward compatibility. connection: Auth0 database connection name (realm). organization: Auth0 organization ID or name. store_options: Optional options for domain resolution. @@ -2524,73 +2526,52 @@ async def passkey_signup_challenge( PasskeySignupChallengeResponse with auth_session and authn_params_public_key. Raises: - ApiError: If the challenge request fails. + PasskeyError: If the challenge request fails. """ try: domain = await self._resolve_current_domain(store_options) - user_profile: dict[str, Any] = {} - if email is not None: - user_profile["email"] = email - if name is not None: - user_profile["name"] = name - if username is not None: - user_profile["username"] = username - if phone_number is not None: - user_profile["phone_number"] = phone_number - if given_name is not None: - user_profile["given_name"] = given_name - if family_name is not None: - user_profile["family_name"] = family_name - if nickname is not None: - user_profile["nickname"] = nickname - if picture is not None: - user_profile["picture"] = picture - if user_metadata is not None: - user_profile["user_metadata"] = user_metadata - body: dict[str, Any] = {"client_id": self._client_id} if self._client_secret: body["client_secret"] = self._client_secret if user_profile: - body["user_profile"] = user_profile + body["user_profile"] = user_profile.model_dump(exclude_none=True) if connection: body["realm"] = connection if organization: body["organization"] = organization - url = f"https://{domain}/passkey/register" - async with self._get_http_client() as client: + url = f"https://{domain}{self.PASSKEY_REGISTER_PATH}" response = await client.post(url, json=body) if response.status_code != 200: try: error_data = response.json() except (json.JSONDecodeError, ValueError): - raise ApiError( - "passkey_challenge_error", + raise PasskeyError( + PasskeyErrorCode.CHALLENGE_FAILED, f"Passkey signup challenge failed with status {response.status_code}", ) - raise ApiError( - error_data.get("error", "passkey_challenge_error"), + raise PasskeyError( + error_data.get("error", PasskeyErrorCode.CHALLENGE_FAILED), error_data.get("error_description", "Passkey signup challenge failed"), ) try: data = response.json() except (json.JSONDecodeError, ValueError): - raise ApiError( - "invalid_response", + raise PasskeyError( + PasskeyErrorCode.INVALID_RESPONSE, "Failed to parse passkey signup challenge response as JSON", ) return PasskeySignupChallengeResponse.model_validate(data) except Exception as e: - if isinstance(e, (ApiError, MissingRequiredArgumentError, ValidationError)): + if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError)): raise - raise ApiError("passkey_challenge_error", "Passkey signup challenge failed", e) + raise PasskeyError(PasskeyErrorCode.CHALLENGE_FAILED, "Passkey signup challenge failed", e) from e async def passkey_login_challenge( self, @@ -2630,38 +2611,37 @@ async def passkey_login_challenge( if organization: body["organization"] = organization - url = f"https://{domain}/passkey/challenge" - async with self._get_http_client() as client: + url = f"https://{domain}{self.PASSKEY_CHALLENGE_PATH}" response = await client.post(url, json=body) if response.status_code != 200: try: error_data = response.json() except (json.JSONDecodeError, ValueError): - raise ApiError( - "passkey_challenge_error", + raise PasskeyError( + PasskeyErrorCode.CHALLENGE_FAILED, f"Passkey login challenge failed with status {response.status_code}", ) - raise ApiError( - error_data.get("error", "passkey_challenge_error"), + raise PasskeyError( + error_data.get("error", PasskeyErrorCode.CHALLENGE_FAILED), error_data.get("error_description", "Passkey login challenge failed"), ) try: data = response.json() except (json.JSONDecodeError, ValueError): - raise ApiError( - "invalid_response", + raise PasskeyError( + PasskeyErrorCode.INVALID_RESPONSE, "Failed to parse passkey login challenge response as JSON", ) return PasskeyLoginChallengeResponse.model_validate(data) except Exception as e: - if isinstance(e, (ApiError, MissingRequiredArgumentError, ValidationError)): + if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError)): raise - raise ApiError("passkey_challenge_error", "Passkey login challenge failed", e) + raise PasskeyError(PasskeyErrorCode.CHALLENGE_FAILED, "Passkey login challenge failed", e) from e async def signin_with_passkey( self, @@ -2672,6 +2652,7 @@ async def signin_with_passkey( organization: Optional[str] = None, scope: Optional[str] = None, audience: Optional[str] = None, + dpop_key: Optional["jwk.JWK"] = None, ) -> PasskeyTokenResponse: """ Completes passkey authentication by exchanging the WebAuthn assertion @@ -2690,13 +2671,16 @@ async def signin_with_passkey( organization: Auth0 organization ID or name. scope: OAuth2 scope string. audience: Target API audience. + dpop_key: Optional EC P-256 JWK for DPoP-bound token exchange. When provided, + attaches a DPoP proof header so Auth0 issues a DPoP-bound token + (token_type: DPoP). Required when the tenant mandates DPoP binding. Returns: PasskeyTokenResponse containing access_token, id_token, expires_in, etc. Raises: MissingRequiredArgumentError: If auth_session or authn_response is missing. - ApiError: If token exchange fails. + PasskeyError: If token exchange fails. """ if not auth_session: raise MissingRequiredArgumentError("auth_session") @@ -2709,7 +2693,7 @@ async def signin_with_passkey( token_endpoint = metadata.get("token_endpoint") if not token_endpoint: - raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") + raise PasskeyError(PasskeyErrorCode.TOKEN_EXCHANGE_FAILED, "Token endpoint missing in OIDC metadata") body: dict[str, Any] = { "grant_type": self.GRANT_TYPE_PASSKEY, @@ -2729,26 +2713,43 @@ async def signin_with_passkey( body["audience"] = audience async with self._get_http_client() as client: - response = await client.post(token_endpoint, json=body) + headers = {} + if dpop_key is not None: + headers["DPoP"] = make_dpop_proof_for_token_endpoint( + dpop_key, "POST", token_endpoint + ) + response = await client.post(token_endpoint, json=body, headers=headers) + + # RFC 9449 §8.2 — nonce retry for DPoP token endpoint calls + if ( + dpop_key is not None + and response.status_code == 401 + and response.headers.get("DPoP-Nonce") + ): + nonce = response.headers["DPoP-Nonce"] + headers["DPoP"] = make_dpop_proof_for_token_endpoint( + dpop_key, "POST", token_endpoint, nonce=nonce + ) + response = await client.post(token_endpoint, json=body, headers=headers) if response.status_code != 200: try: error_data = response.json() except (json.JSONDecodeError, ValueError): - raise ApiError( - "passkey_token_error", + raise PasskeyError( + PasskeyErrorCode.TOKEN_EXCHANGE_FAILED, f"Passkey token exchange failed with status {response.status_code}", ) - raise ApiError( - error_data.get("error", "passkey_token_error"), + raise PasskeyError( + error_data.get("error", PasskeyErrorCode.TOKEN_EXCHANGE_FAILED), error_data.get("error_description", "Passkey token exchange failed"), ) try: token_data = response.json() except (json.JSONDecodeError, ValueError): - raise ApiError( - "invalid_response", "Failed to parse passkey token response as JSON" + raise PasskeyError( + PasskeyErrorCode.INVALID_RESPONSE, "Failed to parse passkey token response as JSON" ) if "expires_in" in token_data and "expires_at" not in token_data: @@ -2757,15 +2758,6 @@ async def signin_with_passkey( return PasskeyTokenResponse.model_validate(token_data) except Exception as e: - if isinstance(e, (ApiError, MissingRequiredArgumentError, ValidationError)): + if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError)): raise - raise ApiError("passkey_token_error", "Passkey sign-in failed", e) - - # ============================================================================ - # MFA (Multi-Factor Authentication) - # ============================================================================ - - @property - def mfa(self) -> MfaClient: - """Access the MFA client for multi-factor authentication operations.""" - return self._mfa_client + raise PasskeyError(PasskeyErrorCode.TOKEN_EXCHANGE_FAILED, "Passkey sign-in failed", e) from e diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 8141a18..9494a22 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -465,200 +465,6 @@ class ListConnectedAccountConnectionsResponse(BaseModel): next: Optional[str] = None -# ============================================================================= -# Passkey & MyAccount Authentication Methods Types -# ============================================================================= - - -class PasskeyRpInfo(BaseModel): - id: str - name: str - - -class PasskeyUserInfo(BaseModel): - model_config = ConfigDict(populate_by_name=True) - id: str - name: str - display_name: Optional[str] = Field(None, alias="displayName") - - -class PasskeyPubKeyCredParam(BaseModel): - type: str - alg: int - - -class PasskeyAuthenticatorSelection(BaseModel): - model_config = ConfigDict(populate_by_name=True) - resident_key: Optional[str] = Field(None, alias="residentKey") - user_verification: Optional[str] = Field(None, alias="userVerification") - - -class PasskeyPublicKeyOptions(BaseModel): - model_config = ConfigDict(populate_by_name=True) - challenge: str - rp: Optional[PasskeyRpInfo] = None - rp_id: Optional[str] = Field(None, alias="rpId") - user: Optional[PasskeyUserInfo] = None - pub_key_cred_params: Optional[list[PasskeyPubKeyCredParam]] = Field( - None, alias="pubKeyCredParams" - ) - authenticator_selection: Optional[PasskeyAuthenticatorSelection] = Field( - None, alias="authenticatorSelection" - ) - timeout: Optional[int] = None - user_verification: Optional[str] = Field(None, alias="userVerification") - - -class EnrollAuthenticationMethodRequest(BaseModel): - type: str - email: Optional[str] = None - phone_number: Optional[str] = None - preferred_authentication_method: Optional[str] = None - user_identity_id: Optional[str] = None - connection: Optional[str] = None - - -class EnrollmentChallengeResponse(BaseModel): - authentication_method_id: str - auth_session: str - authn_params_public_key: Optional[PasskeyPublicKeyOptions] = None - - def __repr__(self) -> str: - return ( - f"EnrollmentChallengeResponse(" - f"authentication_method_id={self.authentication_method_id!r}, " - f"auth_session=[REDACTED], " - f"authn_params_public_key={self.authn_params_public_key!r})" - ) - - -class PasskeyAuthResponse(BaseModel): - model_config = ConfigDict(populate_by_name=True) - id: str - raw_id: str = Field(alias="rawId") - type: str - authenticator_attachment: Optional[str] = Field(None, alias="authenticatorAttachment") - response: dict[str, str] - client_extension_results: Optional[dict] = Field(None, alias="clientExtensionResults") - - -class VerifyAuthenticationMethodRequest(BaseModel): - auth_session: str - authn_response: Optional[PasskeyAuthResponse] = None - otp_code: Optional[str] = None - recovery_code: Optional[str] = None - password: Optional[str] = None - - @model_validator(mode="after") - def _check_at_least_one_method(self) -> "VerifyAuthenticationMethodRequest": - has_method = ( - self.authn_response is not None - or (self.otp_code is not None and self.otp_code.strip() != "") - or (self.recovery_code is not None and self.recovery_code.strip() != "") - or (self.password is not None and self.password.strip() != "") - ) - if not has_method: - raise ValueError( - "At least one verification method must be provided: " - "authn_response, otp_code, recovery_code, or password" - ) - return self - - -class AuthenticationMethod(BaseModel): - model_config = ConfigDict(extra="allow", populate_by_name=True) - - id: str - type: str - created_at: str - confirmed: Optional[bool] = None - usage: Optional[list[str]] = None - identity_user_id: Optional[str] = None - credential_device_type: Optional[str] = None - credential_backed_up: Optional[bool] = None - key_id: Optional[str] = None - public_key: Optional[str] = None - transports: Optional[list[str]] = None - user_agent: Optional[str] = None - user_handle: Optional[str] = None - aaguid: Optional[str] = None - relying_party_id: Optional[str] = None - phone_number: Optional[str] = None - preferred_authentication_method: Optional[str] = None - email: Optional[str] = None - name: Optional[str] = None - last_password_reset: Optional[str] = None - - -class UpdateAuthenticationMethodRequest(BaseModel): - name: Optional[str] = None - preferred_authentication_method: Optional[str] = None - - -class ListAuthenticationMethodsResponse(BaseModel): - authentication_methods: list[AuthenticationMethod] - - -class Factor(BaseModel): - model_config = ConfigDict(extra="allow") - name: str - enabled: Optional[bool] = None - trial_expired: Optional[bool] = None - - -class GetFactorsResponse(BaseModel): - factors: list[Factor] - - -class PasskeySignupChallengeResponse(BaseModel): - model_config = ConfigDict(populate_by_name=True) - auth_session: str - authn_params_public_key: PasskeyPublicKeyOptions - - def __repr__(self) -> str: - return ( - f"PasskeySignupChallengeResponse(" - f"auth_session=[REDACTED], " - f"authn_params_public_key={self.authn_params_public_key!r})" - ) - - -class PasskeyLoginChallengeResponse(BaseModel): - model_config = ConfigDict(populate_by_name=True) - auth_session: str - authn_params_public_key: PasskeyPublicKeyOptions - - def __repr__(self) -> str: - return ( - f"PasskeyLoginChallengeResponse(" - f"auth_session=[REDACTED], " - f"authn_params_public_key={self.authn_params_public_key!r})" - ) - - -class PasskeyTokenResponse(BaseModel): - model_config = ConfigDict(extra="allow", populate_by_name=True) - access_token: str - token_type: str = "Bearer" - expires_in: int - expires_at: Optional[int] = None - scope: Optional[str] = None - id_token: Optional[str] = None - refresh_token: Optional[str] = None - - def __repr__(self) -> str: - return ( - f"PasskeyTokenResponse(" - f"token_type={self.token_type!r}, " - f"expires_in={self.expires_in!r}, " - f"expires_at={self.expires_at!r}, " - f"scope={self.scope!r}, " - f"access_token=[REDACTED], " - f"id_token=[REDACTED], " - f"refresh_token=[REDACTED])" - ) - - # ============================================================================= # MFA Types # ============================================================================= @@ -839,3 +645,197 @@ class MfaTokenContext(BaseModel): scope: str mfa_requirements: Optional[MfaRequirements] = None created_at: int + + +# ============================================================================= +# Passkey & MyAccount Authentication Methods Types +# ============================================================================= + + +class PasskeyRpInfo(BaseModel): + id: str + name: str + + +class PasskeyUserInfo(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: str + name: str + display_name: Optional[str] = Field(None, alias="displayName") + + +class PasskeyPubKeyCredParam(BaseModel): + type: str + alg: int + + +class PasskeyAuthenticatorSelection(BaseModel): + model_config = ConfigDict(populate_by_name=True) + resident_key: Optional[str] = Field(None, alias="residentKey") + user_verification: Optional[str] = Field(None, alias="userVerification") + + +class PasskeyPublicKeyOptions(BaseModel): + model_config = ConfigDict(populate_by_name=True) + challenge: str + rp: Optional[PasskeyRpInfo] = None + rp_id: Optional[str] = Field(None, alias="rpId") + user: Optional[PasskeyUserInfo] = None + pub_key_cred_params: Optional[list[PasskeyPubKeyCredParam]] = Field( + None, alias="pubKeyCredParams" + ) + authenticator_selection: Optional[PasskeyAuthenticatorSelection] = Field( + None, alias="authenticatorSelection" + ) + timeout: Optional[int] = None + user_verification: Optional[str] = Field(None, alias="userVerification") + + +EnrollmentType = Literal["passkey", "email", "phone", "totp", "push-notification", "recovery-code", "password"] +PreferredAuthMethod = Literal["sms", "voice"] + + +class EnrollAuthenticationMethodRequest(BaseModel): + type: EnrollmentType + email: Optional[str] = None + phone_number: Optional[str] = None + preferred_authentication_method: Optional[PreferredAuthMethod] = None + user_identity_id: Optional[str] = None + connection: Optional[str] = None + + +class EnrollmentChallengeResponse(BaseModel): + authentication_method_id: str + auth_session: str + authn_params_public_key: Optional[PasskeyPublicKeyOptions] = None + + def __repr__(self) -> str: + return ( + f"EnrollmentChallengeResponse(" + f"authentication_method_id={self.authentication_method_id!r}, " + f"auth_session=[REDACTED], " + f"authn_params_public_key={self.authn_params_public_key!r})" + ) + + +class PasskeyAuthResponse(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: str + raw_id: str = Field(alias="rawId") + type: str + authenticator_attachment: Optional[str] = Field(None, alias="authenticatorAttachment") + response: dict[str, str] + client_extension_results: Optional[dict] = Field(None, alias="clientExtensionResults") + + +class VerifyAuthenticationMethodRequest(BaseModel): + auth_session: str + authn_response: Optional[PasskeyAuthResponse] = None + otp_code: Optional[str] = None + recovery_code: Optional[str] = None + password: Optional[str] = None + + +class AuthenticationMethod(BaseModel): + model_config = ConfigDict(extra="allow") + + id: str + type: str + created_at: str + confirmed: Optional[bool] = None + usage: Optional[list[str]] = None + identity_user_id: Optional[str] = None + credential_device_type: Optional[str] = None + credential_backed_up: Optional[bool] = None + key_id: Optional[str] = None + public_key: Optional[str] = None + transports: Optional[list[str]] = None + user_agent: Optional[str] = None + user_handle: Optional[str] = None + aaguid: Optional[str] = None + relying_party_id: Optional[str] = None + phone_number: Optional[str] = None + preferred_authentication_method: Optional[str] = None + email: Optional[str] = None + name: Optional[str] = None + last_password_reset: Optional[str] = None + + +class UpdateAuthenticationMethodRequest(BaseModel): + name: Optional[str] = None + preferred_authentication_method: Optional[str] = None + + +class ListAuthenticationMethodsResponse(BaseModel): + authentication_methods: list[AuthenticationMethod] + + +class Factor(BaseModel): + model_config = ConfigDict(extra="allow") + name: str + enabled: Optional[bool] = None + trial_expired: Optional[bool] = None + + +class GetFactorsResponse(BaseModel): + factors: list[Factor] + + +class PasskeyUserProfile(BaseModel): + model_config = ConfigDict(extra="allow") + email: Optional[str] = None + name: Optional[str] = None + username: Optional[str] = None + phone_number: Optional[str] = None + given_name: Optional[str] = None + family_name: Optional[str] = None + nickname: Optional[str] = None + picture: Optional[str] = None + user_metadata: Optional[dict[str, Any]] = None + + +class PasskeySignupChallengeResponse(BaseModel): + auth_session: str + authn_params_public_key: PasskeyPublicKeyOptions + + def __repr__(self) -> str: + return ( + f"PasskeySignupChallengeResponse(" + f"auth_session=[REDACTED], " + f"authn_params_public_key={self.authn_params_public_key!r})" + ) + + +class PasskeyLoginChallengeResponse(BaseModel): + auth_session: str + authn_params_public_key: PasskeyPublicKeyOptions + + def __repr__(self) -> str: + return ( + f"PasskeyLoginChallengeResponse(" + f"auth_session=[REDACTED], " + f"authn_params_public_key={self.authn_params_public_key!r})" + ) + + +class PasskeyTokenResponse(BaseModel): + model_config = ConfigDict(extra="allow") + access_token: str + token_type: str = "Bearer" + expires_in: int + expires_at: Optional[int] = None + scope: Optional[str] = None + id_token: Optional[str] = None + refresh_token: Optional[str] = None + + def __repr__(self) -> str: + return ( + f"PasskeyTokenResponse(" + f"token_type={self.token_type!r}, " + f"expires_in={self.expires_in!r}, " + f"expires_at={self.expires_at!r}, " + f"scope={self.scope!r}, " + f"access_token=[REDACTED], " + f"id_token=[REDACTED], " + f"refresh_token=[REDACTED])" + ) diff --git a/src/auth0_server_python/error/__init__.py b/src/auth0_server_python/error/__init__.py index db4f28e..615c112 100644 --- a/src/auth0_server_python/error/__init__.py +++ b/src/auth0_server_python/error/__init__.py @@ -229,6 +229,24 @@ class CustomTokenExchangeErrorCode: INVALID_RESPONSE = "invalid_response" +class PasskeyError(Auth0Error): + """ + Error raised during passkey authentication operations. + """ + def __init__(self, code: str, message: str, cause=None): + super().__init__(message) + self.code = code + self.name = "PasskeyError" + self.cause = cause + + +class PasskeyErrorCode: + """Error codes for passkey operations.""" + CHALLENGE_FAILED = "passkey_challenge_error" + TOKEN_EXCHANGE_FAILED = "passkey_token_error" + INVALID_RESPONSE = "invalid_response" + + # ============================================================================= # MFA Error Classes # ============================================================================= diff --git a/src/auth0_server_python/tests/test_my_account_client.py b/src/auth0_server_python/tests/test_my_account_client.py index e4ff74c..da2875d 100644 --- a/src/auth0_server_python/tests/test_my_account_client.py +++ b/src/auth0_server_python/tests/test_my_account_client.py @@ -1,9 +1,13 @@ from unittest.mock import ANY, AsyncMock, MagicMock +import httpx import pytest +from jwcrypto import jwk as jwk_module +from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth from auth0_server_python.auth_server.my_account_client import MyAccountClient from auth0_server_python.auth_types import ( + AuthenticationMethod, CompleteConnectAccountRequest, CompleteConnectAccountResponse, ConnectAccountRequest, @@ -11,10 +15,18 @@ ConnectedAccount, ConnectedAccountConnection, ConnectParams, + EnrollAuthenticationMethodRequest, + EnrollmentChallengeResponse, + GetFactorsResponse, + ListAuthenticationMethodsResponse, ListConnectedAccountConnectionsResponse, ListConnectedAccountsResponse, + PasskeyAuthResponse, + UpdateAuthenticationMethodRequest, + VerifyAuthenticationMethodRequest, ) from auth0_server_python.error import ( + ApiError, InvalidArgumentError, MissingRequiredArgumentError, MyAccountApiError, @@ -502,3 +514,826 @@ async def test_list_connected_account_connections_api_response_failure(mocker): mock_get.assert_awaited_once() assert "Invalid Token" in str(exc.value) + +# ============================================================================= +# AUTHENTICATION METHODS & FACTORS (Passkey / MyAccount API) +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_factors_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_factors(access_token="token123") + + assert isinstance(result, GetFactorsResponse) + assert len(result.factors) == 1 + assert result.factors[0].name == "sms" + assert result.factors[0].enabled is True + + +@pytest.mark.asyncio +@pytest.mark.parametrize("access_token", [None, ""]) +async def test_get_factors_missing_access_token(mocker, access_token): + client = MyAccountClient(domain="auth0.local") + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.get_factors(access_token=access_token) + + mock_get.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_factors_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 403 + response.json = MagicMock(return_value={ + "title": "Forbidden", + "type": "forbidden", + "detail": "Insufficient scope", + "status": 403, + }) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.get_factors(access_token="token123") + + assert exc.value.status == 403 + + +@pytest.mark.asyncio +async def test_get_factors_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("Connection refused") + ) + + with pytest.raises(ApiError): + await client.get_factors(access_token="token123") + + +@pytest.mark.asyncio +async def test_get_factors_empty_list(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"factors": []}) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_factors(access_token="token123") + assert result.factors == [] + + +@pytest.mark.asyncio +async def test_get_factors_extra_fields(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "factors": [{"name": "webauthn-roaming", "enabled": True, "future_field": "value"}] + }) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_factors(access_token="token123") + assert result.factors[0].name == "webauthn-roaming" + + +@pytest.mark.asyncio +async def test_list_authentication_methods_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "authentication_methods": [ + {"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z", "key_id": "kid1"} + ] + }) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.list_authentication_methods(access_token="token123") + assert isinstance(result, ListAuthenticationMethodsResponse) + assert len(result.authentication_methods) == 1 + assert result.authentication_methods[0].type == "passkey" + + +@pytest.mark.asyncio +async def test_list_authentication_methods_with_type_filter(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"authentication_methods": []}) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + await client.list_authentication_methods(access_token="token123", type_filter="passkey") + mock_get.assert_awaited_once() + call_kwargs = mock_get.call_args[1] + assert call_kwargs["params"] == {"type": "passkey"} + + +@pytest.mark.asyncio +async def test_list_authentication_methods_empty(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"authentication_methods": []}) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.list_authentication_methods(access_token="token123") + assert result.authentication_methods == [] + + +@pytest.mark.asyncio +async def test_get_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z" + }) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + result = await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert isinstance(result, AuthenticationMethod) + assert result.id == "am_1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("method_id", [None, ""]) +async def test_get_authentication_method_missing_id(mocker, method_id): + client = MyAccountClient(domain="auth0.local") + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.get_authentication_method( + access_token="token123", authentication_method_id=method_id + ) + + mock_get.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_authentication_method_path_traversal(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "id": "id/slash", "type": "passkey", "created_at": "2026-01-01T00:00:00Z" + }) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + await client.get_authentication_method( + access_token="token123", authentication_method_id="id/slash" + ) + call_url = mock_get.call_args[1]["url"] + assert "id%2Fslash" in call_url + assert "id/slash" not in call_url.replace("https://auth0.local/me/", "") + + +@pytest.mark.asyncio +async def test_get_authentication_method_pipe_encoding(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "id": "passkey|new", "type": "passkey", "created_at": "2026-01-01T00:00:00Z" + }) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + await client.get_authentication_method( + access_token="token123", authentication_method_id="passkey|new" + ) + call_url = mock_get.call_args[1]["url"] + assert "passkey%7Cnew" in call_url + + +@pytest.mark.asyncio +async def test_delete_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 204 + mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response) + + result = await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert result is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("method_id", [None, ""]) +async def test_delete_authentication_method_missing_id(mocker, method_id): + client = MyAccountClient(domain="auth0.local") + mock_delete = mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.delete_authentication_method( + access_token="token123", authentication_method_id=method_id + ) + + mock_delete.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_update_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z", "name": "My Key", + }) + mock_patch = mocker.patch( + "httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response + ) + + req = UpdateAuthenticationMethodRequest(name="My Key") + result = await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + assert result.name == "My Key" + call_kwargs = mock_patch.call_args[1] + assert call_kwargs["json"] == {"name": "My Key"} + + +@pytest.mark.asyncio +async def test_update_authentication_method_missing_request(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch("httpx.AsyncClient.patch", new_callable=AsyncMock) + + with pytest.raises(MissingRequiredArgumentError): + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=None + ) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} + response.json = MagicMock(return_value={ + "auth_session": "session_abc", + "authn_params_public_key": { + "challenge": "dGVzdA", + "rp": {"id": "auth0.local", "name": "My App"}, + "user": {"id": "dXNlcl8x", "name": "user@test.com", "displayName": "Test User"}, + "pubKeyCredParams": [{"type": "public-key", "alg": -7}], + "authenticatorSelection": {"residentKey": "required", "userVerification": "preferred"}, + "timeout": 60000, + }, + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + + assert isinstance(result, EnrollmentChallengeResponse) + assert result.authentication_method_id == "passkey|new" + assert result.auth_session == "session_abc" + assert result.authn_params_public_key is not None + assert result.authn_params_public_key.pub_key_cred_params[0].alg == -7 + assert result.authn_params_public_key.authenticator_selection.resident_key == "required" + assert result.authn_params_public_key.user.display_name == "Test User" + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_missing_location(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(ApiError) as exc: + await client.enroll_authentication_method(access_token="token123", request=req) + + assert "Location header" in str(exc.value) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_location_with_query(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/abc123?tracking=1"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + assert result.authentication_method_id == "abc123" + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_location_absolute_url(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "https://tenant.auth0.com/me/v1/authentication-methods/am_xyz"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + assert result.authentication_method_id == "am_xyz" + + +@pytest.mark.asyncio +async def test_verify_authentication_method_success(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.json = MagicMock(return_value={ + "id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z", "confirmed": True, + }) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + authn_response = PasskeyAuthResponse( + id="cred1", + raw_id="cmF3MQ", + type="public-key", + authenticator_attachment="platform", + response={"clientDataJSON": "abc", "attestationObject": "def"}, + ) + req = VerifyAuthenticationMethodRequest( + auth_session="session_abc", authn_response=authn_response + ) + result = await client.verify_authentication_method( + access_token="token123", authentication_method_id="passkey|new", request=req + ) + + assert isinstance(result, AuthenticationMethod) + assert result.confirmed is True + + call_kwargs = mock_post.call_args[1] + body = call_kwargs["json"] + assert "rawId" in body["authn_response"] + assert "raw_id" not in body["authn_response"] + assert "authenticatorAttachment" in body["authn_response"] + assert body["auth_session"] == "session_abc" + assert "passkey%7Cnew" in call_kwargs["url"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("method_id", [None, ""]) +async def test_verify_authentication_method_missing_id(mocker, method_id): + client = MyAccountClient(domain="auth0.local") + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + with pytest.raises(MissingRequiredArgumentError): + await client.verify_authentication_method( + access_token="token123", authentication_method_id=method_id, request=req + ) + + +def test_enrollment_challenge_response_repr(): + resp = EnrollmentChallengeResponse( + authentication_method_id="am_1", + auth_session="super_secret_session", + authn_params_public_key=None, + ) + repr_str = repr(resp) + assert "super_secret_session" not in repr_str + assert "[REDACTED]" in repr_str + assert "am_1" in repr_str + + +def test_verify_request_auth_session_only_is_valid(): + req = VerifyAuthenticationMethodRequest(auth_session="session_abc") + assert req.auth_session == "session_abc" + assert req.otp_code is None + assert req.authn_response is None + + +def test_verify_request_accepts_otp_code(): + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + assert req.otp_code == "123456" + + +def test_verify_request_accepts_authn_response(): + authn_resp = PasskeyAuthResponse( + id="cred1", + raw_id="cmF3MQ", + type="public-key", + response={"clientDataJSON": "abc", "attestationObject": "def"}, + ) + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", authn_response=authn_resp) + assert req.authn_response is not None + + +@pytest.mark.asyncio +async def test_get_factors_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.get_factors(access_token="token123", dpop_key=dpop_key) + + mock_get.assert_awaited_once() + call_kwargs = mock_get.call_args[1] + assert isinstance(call_kwargs["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_list_authentication_methods_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={"authentication_methods": []}) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.list_authentication_methods(access_token="token123", dpop_key=dpop_key) + + assert isinstance(mock_get.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_get_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z" + }) + mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1", dpop_key=dpop_key + ) + + assert isinstance(mock_get.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_delete_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 204 + mock_delete = mocker.patch( + "httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1", dpop_key=dpop_key + ) + + assert isinstance(mock_delete.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_update_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 200 + response.json = MagicMock(return_value={ + "id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z" + }) + mock_patch = mocker.patch( + "httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + req = UpdateAuthenticationMethodRequest(name="New Name") + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req, dpop_key=dpop_key + ) + + assert isinstance(mock_patch.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + req = EnrollAuthenticationMethodRequest(type="passkey") + await client.enroll_authentication_method( + access_token="token123", request=req, dpop_key=dpop_key + ) + + assert isinstance(mock_post.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_verify_authentication_method_with_dpop_key(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.json = MagicMock(return_value={ + "id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z" + }) + mock_post = mocker.patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response + ) + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + await client.verify_authentication_method( + access_token="token123", + authentication_method_id="am_1", + request=req, + dpop_key=dpop_key, + ) + + assert isinstance(mock_post.call_args[1]["auth"], DPoPAuth) + + +@pytest.mark.asyncio +async def test_list_authentication_methods_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 403 + response.json = MagicMock(return_value={ + "title": "Forbidden", "type": "forbidden", "detail": "Insufficient scope", "status": 403, + }) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.list_authentication_methods(access_token="token123") + assert exc.value.status == 403 + + +@pytest.mark.asyncio +async def test_list_authentication_methods_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("Connection refused") + ) + + with pytest.raises(ApiError): + await client.list_authentication_methods(access_token="token123") + + +@pytest.mark.asyncio +async def test_get_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 404 + response.json = MagicMock(return_value={ + "title": "Not Found", "type": "not_found", "detail": "Not found", "status": 404, + }) + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert exc.value.status == 404 + + +@pytest.mark.asyncio +async def test_get_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("timeout")) + + with pytest.raises(ApiError): + await client.get_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + + +@pytest.mark.asyncio +async def test_delete_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 404 + response.json = MagicMock(return_value={ + "title": "Not Found", "type": "not_found", "detail": "Not found", "status": 404, + }) + mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response) + + with pytest.raises(MyAccountApiError) as exc: + await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + assert exc.value.status == 404 + + +@pytest.mark.asyncio +async def test_delete_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.delete", + new_callable=AsyncMock, + side_effect=Exception("Connection reset"), + ) + + with pytest.raises(ApiError): + await client.delete_authentication_method( + access_token="token123", authentication_method_id="am_1" + ) + + +@pytest.mark.asyncio +async def test_update_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 422 + response.json = MagicMock(return_value={ + "title": "Unprocessable", "type": "validation_error", "detail": "Invalid", "status": 422, + }) + mocker.patch("httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response) + + req = UpdateAuthenticationMethodRequest(name="x") + with pytest.raises(MyAccountApiError) as exc: + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + assert exc.value.status == 422 + + +@pytest.mark.asyncio +async def test_update_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.patch", new_callable=AsyncMock, side_effect=Exception("timeout") + ) + + req = UpdateAuthenticationMethodRequest(name="x") + with pytest.raises(ApiError): + await client.update_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 403 + response.json = MagicMock(return_value={ + "title": "Forbidden", "type": "forbidden", "detail": "Scope missing", "status": 403, + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(MyAccountApiError) as exc: + await client.enroll_authentication_method(access_token="token123", request=req) + assert exc.value.status == 403 + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.post", + new_callable=AsyncMock, + side_effect=Exception("Connection refused"), + ) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(ApiError): + await client.enroll_authentication_method(access_token="token123", request=req) + + +@pytest.mark.asyncio +async def test_verify_authentication_method_api_error(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 400 + response.json = MagicMock(return_value={ + "title": "Bad Request", "type": "invalid_request", "detail": "Invalid OTP", "status": 400, + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="000000") + with pytest.raises(MyAccountApiError) as exc: + await client.verify_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + assert exc.value.status == 400 + + +@pytest.mark.asyncio +async def test_verify_authentication_method_network_error(mocker): + client = MyAccountClient(domain="auth0.local") + mocker.patch( + "httpx.AsyncClient.post", + new_callable=AsyncMock, + side_effect=Exception("Connection refused"), + ) + + req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") + with pytest.raises(ApiError): + await client.verify_authentication_method( + access_token="token123", authentication_method_id="am_1", request=req + ) + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_location_collection_url(mocker): + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/"} + response.json = MagicMock(return_value={"auth_session": "session_abc"}) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + with pytest.raises(ApiError) as exc: + await client.enroll_authentication_method(access_token="token123", request=req) + assert "could not extract ID" in str(exc.value) + + +# ============================================================================= +# DPoP nonce retry (RFC 9449 §8.2) — tests DPoPAuth.auth_flow directly +# ============================================================================= + + +def test_dpop_auth_flow_retries_with_nonce_on_401(): + """ + DPoPAuth.auth_flow() must retry with DPoP-Nonce when server responds 401 + + DPoP-Nonce header (RFC 9449 §8.2). Tested by driving the generator directly. + """ + import base64 + import json as _json + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + auth = DPoPAuth(token="test_access_token", key=dpop_key) + + request = httpx.Request("GET", "https://auth0.local/me/v1/factors") + flow = auth.auth_flow(request) + + # First yield — initial request + first_request = next(flow) + assert "DPoP" in first_request.headers + assert "Authorization" in first_request.headers + + # First proof must not have nonce + proof1 = first_request.headers["DPoP"] + payload1_b64 = proof1.split(".")[1] + padding = 4 - len(payload1_b64) % 4 + payload1 = _json.loads(base64.urlsafe_b64decode(payload1_b64 + "=" * padding)) + assert "nonce" not in payload1 + + # Simulate 401 + DPoP-Nonce response + nonce_response = httpx.Response( + status_code=401, + headers={"DPoP-Nonce": "server-nonce-abc"}, + content=b'{"error":"use_dpop_nonce"}', + request=request, + ) + + # Second yield — retry request with nonce + try: + second_request = flow.send(nonce_response) + except StopIteration: + second_request = None + + assert second_request is not None + proof2 = second_request.headers["DPoP"] + payload2_b64 = proof2.split(".")[1] + padding = 4 - len(payload2_b64) % 4 + payload2 = _json.loads(base64.urlsafe_b64decode(payload2_b64 + "=" * padding)) + assert payload2["nonce"] == "server-nonce-abc" + + +def test_dpop_auth_flow_no_retry_on_non_401(): + """DPoPAuth.auth_flow() must NOT retry when the response is not 401.""" + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + auth = DPoPAuth(token="test_access_token", key=dpop_key) + + request = httpx.Request("GET", "https://auth0.local/me/v1/factors") + flow = auth.auth_flow(request) + next(flow) + + success_response = httpx.Response( + status_code=200, + content=b'{"factors":[]}', + request=request, + ) + + try: + flow.send(success_response) + retried = True + except StopIteration: + retried = False + + assert not retried + diff --git a/src/auth0_server_python/tests/test_passkey_my_account.py b/src/auth0_server_python/tests/test_passkey_my_account.py deleted file mode 100644 index 4b7f29d..0000000 --- a/src/auth0_server_python/tests/test_passkey_my_account.py +++ /dev/null @@ -1,830 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock - -import pytest -from jwcrypto import jwk as jwk_module - -from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth -from auth0_server_python.auth_server.my_account_client import MyAccountClient -from auth0_server_python.auth_types import ( - AuthenticationMethod, - EnrollAuthenticationMethodRequest, - EnrollmentChallengeResponse, - GetFactorsResponse, - ListAuthenticationMethodsResponse, - PasskeyAuthResponse, - UpdateAuthenticationMethodRequest, - VerifyAuthenticationMethodRequest, -) -from auth0_server_python.error import ApiError, MissingRequiredArgumentError, MyAccountApiError - - -@pytest.mark.asyncio -async def test_get_factors_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - result = await client.get_factors(access_token="token123") - - assert isinstance(result, GetFactorsResponse) - assert len(result.factors) == 1 - assert result.factors[0].name == "sms" - assert result.factors[0].enabled is True - - -@pytest.mark.asyncio -@pytest.mark.parametrize("access_token", [None, ""]) -async def test_get_factors_missing_access_token(mocker, access_token): - client = MyAccountClient(domain="auth0.local") - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock) - - with pytest.raises(MissingRequiredArgumentError): - await client.get_factors(access_token=access_token) - - mock_get.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_get_factors_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 403 - response.json = MagicMock( - return_value={ - "title": "Forbidden", - "type": "forbidden", - "detail": "Insufficient scope", - "status": 403, - } - ) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - with pytest.raises(MyAccountApiError) as exc: - await client.get_factors(access_token="token123") - - assert exc.value.status == 403 - - -@pytest.mark.asyncio -async def test_get_factors_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch( - "httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("Connection refused") - ) - - with pytest.raises(ApiError): - await client.get_factors(access_token="token123") - - -@pytest.mark.asyncio -async def test_get_factors_empty_list(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock(return_value={"factors": []}) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - result = await client.get_factors(access_token="token123") - assert result.factors == [] - - -@pytest.mark.asyncio -async def test_get_factors_extra_fields(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={ - "factors": [{"name": "webauthn-roaming", "enabled": True, "future_field": "value"}] - } - ) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - result = await client.get_factors(access_token="token123") - assert result.factors[0].name == "webauthn-roaming" - - -@pytest.mark.asyncio -async def test_list_authentication_methods_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={ - "authentication_methods": [ - { - "id": "am_1", - "type": "passkey", - "created_at": "2026-01-01T00:00:00Z", - "key_id": "kid1", - } - ] - } - ) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - result = await client.list_authentication_methods(access_token="token123") - assert isinstance(result, ListAuthenticationMethodsResponse) - assert len(result.authentication_methods) == 1 - assert result.authentication_methods[0].type == "passkey" - - -@pytest.mark.asyncio -async def test_list_authentication_methods_with_type_filter(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock(return_value={"authentication_methods": []}) - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - await client.list_authentication_methods(access_token="token123", type_filter="passkey") - mock_get.assert_awaited_once() - call_kwargs = mock_get.call_args[1] - assert call_kwargs["params"] == {"type": "passkey"} - - -@pytest.mark.asyncio -async def test_list_authentication_methods_empty(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock(return_value={"authentication_methods": []}) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - result = await client.list_authentication_methods(access_token="token123") - assert result.authentication_methods == [] - - -@pytest.mark.asyncio -async def test_get_authentication_method_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} - ) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - result = await client.get_authentication_method( - access_token="token123", authentication_method_id="am_1" - ) - assert isinstance(result, AuthenticationMethod) - assert result.id == "am_1" - - -@pytest.mark.asyncio -@pytest.mark.parametrize("method_id", [None, ""]) -async def test_get_authentication_method_missing_id(mocker, method_id): - client = MyAccountClient(domain="auth0.local") - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock) - - with pytest.raises(MissingRequiredArgumentError): - await client.get_authentication_method( - access_token="token123", authentication_method_id=method_id - ) - - mock_get.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_get_authentication_method_path_traversal(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={"id": "id/slash", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} - ) - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - await client.get_authentication_method( - access_token="token123", authentication_method_id="id/slash" - ) - call_url = mock_get.call_args[1]["url"] - assert "id%2Fslash" in call_url - assert "id/slash" not in call_url.replace("https://auth0.local/me/", "") - - -@pytest.mark.asyncio -async def test_get_authentication_method_pipe_encoding(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={"id": "passkey|new", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} - ) - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - await client.get_authentication_method( - access_token="token123", authentication_method_id="passkey|new" - ) - call_url = mock_get.call_args[1]["url"] - assert "passkey%7Cnew" in call_url - - -@pytest.mark.asyncio -async def test_delete_authentication_method_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 204 - mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response) - - result = await client.delete_authentication_method( - access_token="token123", authentication_method_id="am_1" - ) - assert result is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize("method_id", [None, ""]) -async def test_delete_authentication_method_missing_id(mocker, method_id): - client = MyAccountClient(domain="auth0.local") - mock_delete = mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock) - - with pytest.raises(MissingRequiredArgumentError): - await client.delete_authentication_method( - access_token="token123", authentication_method_id=method_id - ) - - mock_delete.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_update_authentication_method_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={ - "id": "am_1", - "type": "passkey", - "created_at": "2026-01-01T00:00:00Z", - "name": "My Key", - } - ) - mock_patch = mocker.patch( - "httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response - ) - - req = UpdateAuthenticationMethodRequest(name="My Key") - result = await client.update_authentication_method( - access_token="token123", authentication_method_id="am_1", request=req - ) - assert result.name == "My Key" - call_kwargs = mock_patch.call_args[1] - assert call_kwargs["json"] == {"name": "My Key"} - - -@pytest.mark.asyncio -async def test_update_authentication_method_missing_request(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch("httpx.AsyncClient.patch", new_callable=AsyncMock) - - with pytest.raises(MissingRequiredArgumentError): - await client.update_authentication_method( - access_token="token123", authentication_method_id="am_1", request=None - ) - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 201 - response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} - response.json = MagicMock( - return_value={ - "auth_session": "session_abc", - "authn_params_public_key": { - "challenge": "dGVzdA", - "rp": {"id": "auth0.local", "name": "My App"}, - "user": {"id": "dXNlcl8x", "name": "user@test.com", "displayName": "Test User"}, - "pubKeyCredParams": [{"type": "public-key", "alg": -7}], - "authenticatorSelection": { - "residentKey": "required", - "userVerification": "preferred", - }, - "timeout": 60000, - }, - } - ) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = EnrollAuthenticationMethodRequest(type="passkey") - result = await client.enroll_authentication_method(access_token="token123", request=req) - - assert isinstance(result, EnrollmentChallengeResponse) - assert result.authentication_method_id == "passkey|new" - assert result.auth_session == "session_abc" - assert result.authn_params_public_key is not None - assert result.authn_params_public_key.pub_key_cred_params[0].alg == -7 - assert result.authn_params_public_key.authenticator_selection.resident_key == "required" - assert result.authn_params_public_key.user.display_name == "Test User" - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_missing_location(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 201 - response.headers = {} - response.json = MagicMock(return_value={"auth_session": "session_abc"}) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = EnrollAuthenticationMethodRequest(type="passkey") - with pytest.raises(ApiError) as exc: - await client.enroll_authentication_method(access_token="token123", request=req) - - assert "Location header" in str(exc.value) - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_location_with_query(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 201 - response.headers = {"location": "/me/v1/authentication-methods/abc123?tracking=1"} - response.json = MagicMock(return_value={"auth_session": "session_abc"}) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = EnrollAuthenticationMethodRequest(type="passkey") - result = await client.enroll_authentication_method(access_token="token123", request=req) - assert result.authentication_method_id == "abc123" - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_location_absolute_url(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 201 - response.headers = {"location": "https://tenant.auth0.com/me/v1/authentication-methods/am_xyz"} - response.json = MagicMock(return_value={"auth_session": "session_abc"}) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = EnrollAuthenticationMethodRequest(type="passkey") - result = await client.enroll_authentication_method(access_token="token123", request=req) - assert result.authentication_method_id == "am_xyz" - - -@pytest.mark.asyncio -async def test_verify_authentication_method_success(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={ - "id": "am_1", - "type": "passkey", - "created_at": "2026-01-01T00:00:00Z", - "confirmed": True, - } - ) - mock_post = mocker.patch( - "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response - ) - - authn_response = PasskeyAuthResponse( - id="cred1", - raw_id="cmF3MQ", - type="public-key", - authenticator_attachment="platform", - response={"clientDataJSON": "abc", "attestationObject": "def"}, - ) - req = VerifyAuthenticationMethodRequest( - auth_session="session_abc", authn_response=authn_response - ) - result = await client.verify_authentication_method( - access_token="token123", authentication_method_id="passkey|new", request=req - ) - - assert isinstance(result, AuthenticationMethod) - assert result.confirmed is True - - call_kwargs = mock_post.call_args[1] - body = call_kwargs["json"] - assert "rawId" in body["authn_response"] - assert "raw_id" not in body["authn_response"] - assert "authenticatorAttachment" in body["authn_response"] - assert body["auth_session"] == "session_abc" - assert "passkey%7Cnew" in call_kwargs["url"] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("method_id", [None, ""]) -async def test_verify_authentication_method_missing_id(mocker, method_id): - client = MyAccountClient(domain="auth0.local") - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) - - req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") - with pytest.raises(MissingRequiredArgumentError): - await client.verify_authentication_method( - access_token="token123", authentication_method_id=method_id, request=req - ) - - -@pytest.mark.asyncio -async def test_enrollment_challenge_response_repr(): - resp = EnrollmentChallengeResponse( - authentication_method_id="am_1", - auth_session="super_secret_session", - authn_params_public_key=None, - ) - repr_str = repr(resp) - assert "super_secret_session" not in repr_str - assert "[REDACTED]" in repr_str - assert "am_1" in repr_str - - -def test_verify_request_requires_at_least_one_method(): - with pytest.raises(Exception, match="At least one verification method"): - VerifyAuthenticationMethodRequest(auth_session="session_abc") - - -def test_verify_request_accepts_otp_code(): - req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") - assert req.otp_code == "123456" - - -def test_verify_request_accepts_authn_response(): - authn_resp = PasskeyAuthResponse( - id="cred1", - raw_id="cmF3MQ", - type="public-key", - response={"clientDataJSON": "abc", "attestationObject": "def"}, - ) - req = VerifyAuthenticationMethodRequest(auth_session="session_abc", authn_response=authn_resp) - assert req.authn_response is not None - - -@pytest.mark.asyncio -async def test_get_factors_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock(return_value={"factors": [{"name": "sms", "enabled": True}]}) - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - await client.get_factors(access_token="token123", dpop_key=dpop_key) - - mock_get.assert_awaited_once() - call_kwargs = mock_get.call_args[1] - assert isinstance(call_kwargs["auth"], DPoPAuth) - - -# ============================================================================= -# DPoP integration(mock) tests -# ============================================================================= - - -@pytest.mark.asyncio -async def test_list_authentication_methods_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock(return_value={"authentication_methods": []}) - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - await client.list_authentication_methods(access_token="token123", dpop_key=dpop_key) - - assert isinstance(mock_get.call_args[1]["auth"], DPoPAuth) - - -@pytest.mark.asyncio -async def test_get_authentication_method_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} - ) - mock_get = mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - await client.get_authentication_method( - access_token="token123", authentication_method_id="am_1", dpop_key=dpop_key - ) - - assert isinstance(mock_get.call_args[1]["auth"], DPoPAuth) - - -@pytest.mark.asyncio -async def test_delete_authentication_method_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 204 - mock_delete = mocker.patch( - "httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response - ) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - await client.delete_authentication_method( - access_token="token123", authentication_method_id="am_1", dpop_key=dpop_key - ) - - assert isinstance(mock_delete.call_args[1]["auth"], DPoPAuth) - - -@pytest.mark.asyncio -async def test_update_authentication_method_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} - ) - mock_patch = mocker.patch( - "httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response - ) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - req = UpdateAuthenticationMethodRequest(name="New Name") - await client.update_authentication_method( - access_token="token123", authentication_method_id="am_1", request=req, dpop_key=dpop_key - ) - - assert isinstance(mock_patch.call_args[1]["auth"], DPoPAuth) - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 201 - response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} - response.json = MagicMock(return_value={"auth_session": "session_abc"}) - mock_post = mocker.patch( - "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response - ) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - req = EnrollAuthenticationMethodRequest(type="passkey") - await client.enroll_authentication_method( - access_token="token123", request=req, dpop_key=dpop_key - ) - - assert isinstance(mock_post.call_args[1]["auth"], DPoPAuth) - - -@pytest.mark.asyncio -async def test_verify_authentication_method_with_dpop_key(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 200 - response.json = MagicMock( - return_value={"id": "am_1", "type": "passkey", "created_at": "2026-01-01T00:00:00Z"} - ) - mock_post = mocker.patch( - "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response - ) - - dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") - req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") - await client.verify_authentication_method( - access_token="token123", - authentication_method_id="am_1", - request=req, - dpop_key=dpop_key, - ) - - assert isinstance(mock_post.call_args[1]["auth"], DPoPAuth) - - -# ============================================================================= -# API error and network error tests -# ============================================================================= - - -@pytest.mark.asyncio -async def test_list_authentication_methods_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 403 - response.json = MagicMock( - return_value={ - "title": "Forbidden", - "type": "forbidden", - "detail": "Insufficient scope", - "status": 403, - } - ) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - with pytest.raises(MyAccountApiError) as exc: - await client.list_authentication_methods(access_token="token123") - assert exc.value.status == 403 - - -@pytest.mark.asyncio -async def test_list_authentication_methods_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch( - "httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("Connection refused") - ) - - with pytest.raises(ApiError): - await client.list_authentication_methods(access_token="token123") - - -@pytest.mark.asyncio -async def test_get_authentication_method_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 404 - response.json = MagicMock( - return_value={ - "title": "Not Found", - "type": "not_found", - "detail": "Not found", - "status": 404, - } - ) - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=response) - - with pytest.raises(MyAccountApiError) as exc: - await client.get_authentication_method( - access_token="token123", authentication_method_id="am_1" - ) - assert exc.value.status == 404 - - -@pytest.mark.asyncio -async def test_get_authentication_method_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=Exception("timeout")) - - with pytest.raises(ApiError): - await client.get_authentication_method( - access_token="token123", authentication_method_id="am_1" - ) - - -@pytest.mark.asyncio -async def test_delete_authentication_method_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 404 - response.json = MagicMock( - return_value={ - "title": "Not Found", - "type": "not_found", - "detail": "Not found", - "status": 404, - } - ) - mocker.patch("httpx.AsyncClient.delete", new_callable=AsyncMock, return_value=response) - - with pytest.raises(MyAccountApiError) as exc: - await client.delete_authentication_method( - access_token="token123", authentication_method_id="am_1" - ) - assert exc.value.status == 404 - - -@pytest.mark.asyncio -async def test_delete_authentication_method_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch( - "httpx.AsyncClient.delete", - new_callable=AsyncMock, - side_effect=Exception("Connection reset"), - ) - - with pytest.raises(ApiError): - await client.delete_authentication_method( - access_token="token123", authentication_method_id="am_1" - ) - - -@pytest.mark.asyncio -async def test_update_authentication_method_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 422 - response.json = MagicMock( - return_value={ - "title": "Unprocessable", - "type": "validation_error", - "detail": "Invalid", - "status": 422, - } - ) - mocker.patch("httpx.AsyncClient.patch", new_callable=AsyncMock, return_value=response) - - req = UpdateAuthenticationMethodRequest(name="x") - with pytest.raises(MyAccountApiError) as exc: - await client.update_authentication_method( - access_token="token123", authentication_method_id="am_1", request=req - ) - assert exc.value.status == 422 - - -@pytest.mark.asyncio -async def test_update_authentication_method_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch( - "httpx.AsyncClient.patch", new_callable=AsyncMock, side_effect=Exception("timeout") - ) - - req = UpdateAuthenticationMethodRequest(name="x") - with pytest.raises(ApiError): - await client.update_authentication_method( - access_token="token123", authentication_method_id="am_1", request=req - ) - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 403 - response.json = MagicMock( - return_value={ - "title": "Forbidden", - "type": "forbidden", - "detail": "Scope missing", - "status": 403, - } - ) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = EnrollAuthenticationMethodRequest(type="passkey") - with pytest.raises(MyAccountApiError) as exc: - await client.enroll_authentication_method(access_token="token123", request=req) - assert exc.value.status == 403 - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch( - "httpx.AsyncClient.post", - new_callable=AsyncMock, - side_effect=Exception("Connection refused"), - ) - - req = EnrollAuthenticationMethodRequest(type="passkey") - with pytest.raises(ApiError): - await client.enroll_authentication_method(access_token="token123", request=req) - - -@pytest.mark.asyncio -async def test_verify_authentication_method_api_error(mocker): - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 400 - response.json = MagicMock( - return_value={ - "title": "Bad Request", - "type": "invalid_request", - "detail": "Invalid OTP", - "status": 400, - } - ) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="000000") - with pytest.raises(MyAccountApiError) as exc: - await client.verify_authentication_method( - access_token="token123", authentication_method_id="am_1", request=req - ) - assert exc.value.status == 400 - - -@pytest.mark.asyncio -async def test_verify_authentication_method_network_error(mocker): - client = MyAccountClient(domain="auth0.local") - mocker.patch( - "httpx.AsyncClient.post", - new_callable=AsyncMock, - side_effect=Exception("Connection refused"), - ) - - req = VerifyAuthenticationMethodRequest(auth_session="session_abc", otp_code="123456") - with pytest.raises(ApiError): - await client.verify_authentication_method( - access_token="token123", authentication_method_id="am_1", request=req - ) - - -# ============================================================================= -# Location header extraction edge case -# ============================================================================= - - -@pytest.mark.asyncio -async def test_enroll_authentication_method_location_collection_url(mocker): - """Rejects Location header that ends at collection path without resource ID.""" - client = MyAccountClient(domain="auth0.local") - response = AsyncMock() - response.status_code = 201 - response.headers = {"location": "/me/v1/authentication-methods/"} - response.json = MagicMock(return_value={"auth_session": "session_abc"}) - mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) - - req = EnrollAuthenticationMethodRequest(type="passkey") - with pytest.raises(ApiError) as exc: - await client.enroll_authentication_method(access_token="token123", request=req) - assert "could not extract ID" in str(exc.value) diff --git a/src/auth0_server_python/tests/test_passkey_server_client.py b/src/auth0_server_python/tests/test_passkey_server_client.py deleted file mode 100644 index 7c2be37..0000000 --- a/src/auth0_server_python/tests/test_passkey_server_client.py +++ /dev/null @@ -1,585 +0,0 @@ -import time -from unittest.mock import AsyncMock - -import httpx -import pytest - -from auth0_server_python.auth_server.server_client import ServerClient -from auth0_server_python.auth_types import ( - PasskeyAuthResponse, - PasskeyLoginChallengeResponse, - PasskeySignupChallengeResponse, - PasskeyTokenResponse, -) -from auth0_server_python.error import ApiError, MissingRequiredArgumentError - - -@pytest.fixture -def server_client(): - return ServerClient( - domain="auth0.local", - client_id="test_client_id", - client_secret="test_client_secret", - state_store=AsyncMock(), - transaction_store=AsyncMock(), - secret="test-secret-value", - ) - - -SIGNUP_CHALLENGE_RESPONSE = { - "auth_session": "session_abc123", - "authn_params_public_key": { - "challenge": "dGVzdC1jaGFsbGVuZ2U", - "rp": {"id": "auth0.local", "name": "Test App"}, - "user": {"id": "dXNlcl8x", "name": "user@example.com", "displayName": "Jane"}, - "pubKeyCredParams": [{"type": "public-key", "alg": -7}], - "authenticatorSelection": { - "residentKey": "required", - "userVerification": "preferred", - }, - "timeout": 60000, - }, -} - -LOGIN_CHALLENGE_RESPONSE = { - "auth_session": "session_login_xyz", - "authn_params_public_key": { - "challenge": "bG9naW4tY2hhbGxlbmdl", - "rpId": "auth0.local", - "timeout": 60000, - "userVerification": "preferred", - }, -} - -TOKEN_RESPONSE = { - "access_token": "at_passkey_123", - "id_token": "eyJ.test.jwt", - "token_type": "Bearer", - "expires_in": 86400, - "scope": "openid profile", -} - - -def _mock_response(status_code=200, json_data=None, headers=None): - resp = httpx.Response( - status_code=status_code, - json=json_data, - headers=headers or {}, - request=httpx.Request("POST", "https://auth0.local/passkey/register"), - ) - return resp - - -# ============================================================================= -# passkey_signup_challenge -# ============================================================================= - - -@pytest.mark.asyncio -async def test_passkey_signup_challenge_success(server_client, mocker): - mock_response = _mock_response(200, SIGNUP_CHALLENGE_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - result = await server_client.passkey_signup_challenge( - email="user@example.com", - name="Jane Doe", - connection="Username-Password-Authentication", - ) - - assert isinstance(result, PasskeySignupChallengeResponse) - assert result.auth_session == "session_abc123" - assert result.authn_params_public_key.challenge == "dGVzdC1jaGFsbGVuZ2U" - assert result.authn_params_public_key.rp.id == "auth0.local" - assert result.authn_params_public_key.user.display_name == "Jane" - assert result.authn_params_public_key.pub_key_cred_params[0].alg == -7 - assert result.authn_params_public_key.authenticator_selection.resident_key == "required" - - call_args = mock_client.post.call_args - assert "/passkey/register" in call_args.args[0] - body = call_args.kwargs["json"] - assert body["client_id"] == "test_client_id" - assert body["client_secret"] == "test_client_secret" - assert body["user_profile"]["email"] == "user@example.com" - assert body["user_profile"]["name"] == "Jane Doe" - assert body["realm"] == "Username-Password-Authentication" - - -@pytest.mark.asyncio -async def test_passkey_signup_challenge_user_profile_fields(server_client, mocker): - mock_response = _mock_response(200, SIGNUP_CHALLENGE_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - await server_client.passkey_signup_challenge( - email="u@e.com", - username="jdoe", - phone_number="+1234567890", - given_name="Jane", - family_name="Doe", - nickname="jd", - picture="https://example.com/pic.jpg", - user_metadata={"role": "admin"}, - organization="org_123", - ) - - body = mock_client.post.call_args.kwargs["json"] - assert body["user_profile"]["email"] == "u@e.com" - assert body["user_profile"]["username"] == "jdoe" - assert body["user_profile"]["phone_number"] == "+1234567890" - assert body["user_profile"]["given_name"] == "Jane" - assert body["user_profile"]["family_name"] == "Doe" - assert body["user_profile"]["nickname"] == "jd" - assert body["user_profile"]["picture"] == "https://example.com/pic.jpg" - assert body["user_profile"]["user_metadata"] == {"role": "admin"} - assert body["organization"] == "org_123" - - -@pytest.mark.asyncio -async def test_passkey_signup_challenge_minimal_body(server_client, mocker): - mock_response = _mock_response(200, SIGNUP_CHALLENGE_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - await server_client.passkey_signup_challenge() - - body = mock_client.post.call_args.kwargs["json"] - assert body == {"client_id": "test_client_id", "client_secret": "test_client_secret"} - assert "user_profile" not in body - assert "realm" not in body - assert "organization" not in body - - -@pytest.mark.asyncio -async def test_passkey_signup_challenge_api_error(server_client, mocker): - error_resp = _mock_response( - 403, - {"error": "access_denied", "error_description": "Passkey not enabled"}, - ) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=error_resp) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - with pytest.raises(ApiError) as exc: - await server_client.passkey_signup_challenge(email="test@example.com") - assert "access_denied" in str(exc.value) or "Passkey not enabled" in str(exc.value) - - -@pytest.mark.asyncio -async def test_passkey_signup_challenge_non_json_error(server_client, mocker): - resp = httpx.Response( - status_code=502, - content=b"Bad Gateway", - headers={"content-type": "text/html"}, - request=httpx.Request("POST", "https://auth0.local/passkey/register"), - ) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=resp) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - with pytest.raises(ApiError) as exc: - await server_client.passkey_signup_challenge() - assert "502" in str(exc.value) or "passkey_challenge_error" in str(exc.value) - - -@pytest.mark.asyncio -async def test_passkey_signup_challenge_network_error(server_client, mocker): - mock_client = AsyncMock() - mock_client.post = AsyncMock(side_effect=Exception("Connection refused")) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - with pytest.raises(ApiError) as exc: - await server_client.passkey_signup_challenge() - assert "Passkey signup challenge failed" in str(exc.value) - - -# ============================================================================= -# passkey_login_challenge -# ============================================================================= - - -@pytest.mark.asyncio -async def test_passkey_login_challenge_success(server_client, mocker): - mock_response = _mock_response(200, LOGIN_CHALLENGE_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - result = await server_client.passkey_login_challenge( - connection="Username-Password-Authentication", - organization="org_abc", - ) - - assert isinstance(result, PasskeyLoginChallengeResponse) - assert result.auth_session == "session_login_xyz" - assert result.authn_params_public_key.challenge == "bG9naW4tY2hhbGxlbmdl" - assert result.authn_params_public_key.rp_id == "auth0.local" - assert result.authn_params_public_key.user_verification == "preferred" - - body = mock_client.post.call_args.kwargs["json"] - assert body["client_id"] == "test_client_id" - assert body["realm"] == "Username-Password-Authentication" - assert body["organization"] == "org_abc" - - -@pytest.mark.asyncio -async def test_passkey_login_challenge_with_username(server_client, mocker): - mock_response = _mock_response(200, LOGIN_CHALLENGE_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - await server_client.passkey_login_challenge(username="jane@example.com") - - body = mock_client.post.call_args.kwargs["json"] - assert body["username"] == "jane@example.com" - - -@pytest.mark.asyncio -async def test_passkey_login_challenge_api_error(server_client, mocker): - error_resp = _mock_response( - 400, - {"error": "invalid_request", "error_description": "Missing client_id"}, - ) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=error_resp) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - with pytest.raises(ApiError): - await server_client.passkey_login_challenge() - - -@pytest.mark.asyncio -async def test_passkey_login_challenge_network_error(server_client, mocker): - mock_client = AsyncMock() - mock_client.post = AsyncMock(side_effect=Exception("timeout")) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - - with pytest.raises(ApiError): - await server_client.passkey_login_challenge() - - -# ============================================================================= -# signin_with_passkey -# ============================================================================= - - -@pytest.fixture -def authn_response(): - return PasskeyAuthResponse( - id="cred_abc123", - raw_id="Y3JlZF9hYmMxMjM", - type="public-key", - authenticator_attachment="platform", - response={ - "clientDataJSON": "eyJ0eXBlIjoid2ViYXV0aG4uZ2V0In0", - "authenticatorData": "SZYN5YgOjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2M", - "signature": "MEUCIQC", - "userHandle": "dXNlcl8x", - }, - ) - - -@pytest.mark.asyncio -async def test_signin_with_passkey_success(server_client, authn_response, mocker): - mock_response = _mock_response(200, TOKEN_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - result = await server_client.signin_with_passkey( - auth_session="session_xyz", - authn_response=authn_response, - scope="openid profile", - audience="https://api.example.com", - connection="Username-Password-Authentication", - organization="org_abc", - ) - - assert isinstance(result, PasskeyTokenResponse) - assert result.access_token == "at_passkey_123" - assert result.token_type == "Bearer" - assert abs(result.expires_at - (int(time.time()) + 86400)) <= 2 - - body = mock_client.post.call_args.kwargs["json"] - assert body["grant_type"] == "urn:okta:params:oauth:grant-type:webauthn" - assert body["client_id"] == "test_client_id" - assert body["client_secret"] == "test_client_secret" - assert body["auth_session"] == "session_xyz" - assert body["scope"] == "openid profile" - assert body["audience"] == "https://api.example.com" - assert body["realm"] == "Username-Password-Authentication" - assert body["organization"] == "org_abc" - assert body["authn_response"]["rawId"] == "Y3JlZF9hYmMxMjM" - assert body["authn_response"]["authenticatorAttachment"] == "platform" - assert "raw_id" not in body["authn_response"] - - -@pytest.mark.asyncio -async def test_signin_with_passkey_uses_json_content_type(server_client, authn_response, mocker): - mock_response = _mock_response(200, TOKEN_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - await server_client.signin_with_passkey( - auth_session="s", - authn_response=authn_response, - ) - - call_kwargs = mock_client.post.call_args.kwargs - assert "json" in call_kwargs - assert "data" not in call_kwargs - - -@pytest.mark.asyncio -@pytest.mark.parametrize("auth_session", [None, ""]) -async def test_signin_with_passkey_missing_auth_session( - server_client, authn_response, auth_session -): - with pytest.raises(MissingRequiredArgumentError): - await server_client.signin_with_passkey( - auth_session=auth_session, - authn_response=authn_response, - ) - - -@pytest.mark.asyncio -async def test_signin_with_passkey_missing_authn_response(server_client): - with pytest.raises(MissingRequiredArgumentError): - await server_client.signin_with_passkey( - auth_session="session_abc", - authn_response=None, - ) - - -@pytest.mark.asyncio -async def test_signin_with_passkey_api_error(server_client, authn_response, mocker): - error_resp = _mock_response( - 401, - {"error": "invalid_grant", "error_description": "Invalid auth_session"}, - ) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=error_resp) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - with pytest.raises(ApiError) as exc: - await server_client.signin_with_passkey( - auth_session="expired_session", - authn_response=authn_response, - ) - assert "invalid_grant" in str(exc.value) or "Invalid auth_session" in str(exc.value) - - -@pytest.mark.asyncio -async def test_signin_with_passkey_missing_token_endpoint(server_client, authn_response, mocker): - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={}, - ) - - with pytest.raises(ApiError) as exc: - await server_client.signin_with_passkey( - auth_session="session", - authn_response=authn_response, - ) - assert "token endpoint" in str(exc.value).lower() - - -@pytest.mark.asyncio -async def test_signin_with_passkey_network_error(server_client, authn_response, mocker): - mock_client = AsyncMock() - mock_client.post = AsyncMock(side_effect=Exception("Connection reset")) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - with pytest.raises(ApiError): - await server_client.signin_with_passkey( - auth_session="session", - authn_response=authn_response, - ) - - -@pytest.mark.asyncio -async def test_signin_with_passkey_no_client_secret(mocker): - client = ServerClient( - domain="auth0.local", - client_id="public_client", - client_secret=None, - state_store=AsyncMock(), - transaction_store=AsyncMock(), - secret="test-secret", - ) - - mock_response = _mock_response(200, TOKEN_RESPONSE) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - authn_resp = PasskeyAuthResponse( - id="cred", - raw_id="cmF3", - type="public-key", - response={"clientDataJSON": "abc", "authenticatorData": "def", "signature": "ghi"}, - ) - - await client.signin_with_passkey( - auth_session="session", - authn_response=authn_resp, - ) - - body = mock_client.post.call_args.kwargs["json"] - assert "client_secret" not in body - assert body["client_id"] == "public_client" - - -@pytest.mark.asyncio -async def test_signup_challenge_repr_redacts_auth_session(): - resp = PasskeySignupChallengeResponse.model_validate(SIGNUP_CHALLENGE_RESPONSE) - repr_str = repr(resp) - assert "session_abc123" not in repr_str - assert "[REDACTED]" in repr_str - - -@pytest.mark.asyncio -async def test_login_challenge_repr_redacts_auth_session(): - resp = PasskeyLoginChallengeResponse.model_validate(LOGIN_CHALLENGE_RESPONSE) - repr_str = repr(resp) - assert "session_login_xyz" not in repr_str - assert "[REDACTED]" in repr_str - - -def test_passkey_token_response_repr_redacts_tokens(): - resp = PasskeyTokenResponse( - access_token="secret_at_value", - token_type="Bearer", - expires_in=86400, - id_token="secret_id_token", - refresh_token="secret_rt_value", - ) - repr_str = repr(resp) - assert "secret_at_value" not in repr_str - assert "secret_id_token" not in repr_str - assert "secret_rt_value" not in repr_str - assert "[REDACTED]" in repr_str - assert "86400" in repr_str - - -# ============================================================================= -# expires_at edge cases -# ============================================================================= - - -@pytest.mark.asyncio -async def test_signin_with_passkey_preserves_server_expires_at( - server_client, authn_response, mocker -): - token_data = { - "access_token": "at_123", - "token_type": "Bearer", - "expires_in": 3600, - "expires_at": 9999999999, - } - mock_response = _mock_response(200, token_data) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - result = await server_client.signin_with_passkey( - auth_session="session", authn_response=authn_response - ) - - assert result.expires_at == 9999999999 - - -@pytest.mark.asyncio -async def test_signin_with_passkey_missing_expires_at_calculates( - server_client, authn_response, mocker -): - token_data = { - "access_token": "at_123", - "token_type": "Bearer", - "expires_in": 60, - } - mock_response = _mock_response(200, token_data) - mock_client = AsyncMock() - mock_client.post = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mocker.patch.object(server_client, "_get_http_client", return_value=mock_client) - mocker.patch.object( - server_client, - "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, - ) - - result = await server_client.signin_with_passkey( - auth_session="session", authn_response=authn_response - ) - - assert abs(result.expires_at - (int(time.time()) + 60)) <= 2 diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 47ba774..f987567 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -4816,3 +4816,797 @@ async def _fake_fetch(self, domain): assert exc.value.mfa_requirements is not None finally: ServerClient._fetch_oidc_metadata = original_fetch + + +# ============================================================================= +# PASSKEY AUTHENTICATION +# ============================================================================= + +_PASSKEY_SIGNUP_CHALLENGE_RESPONSE = { + "auth_session": "session_abc123", + "authn_params_public_key": { + "challenge": "dGVzdC1jaGFsbGVuZ2U", + "rp": {"id": "auth0.local", "name": "Test App"}, + "user": {"id": "dXNlcl8x", "name": "user@example.com", "displayName": "Jane"}, + "pubKeyCredParams": [{"type": "public-key", "alg": -7}], + "authenticatorSelection": { + "residentKey": "required", + "userVerification": "preferred", + }, + "timeout": 60000, + }, +} + +_PASSKEY_LOGIN_CHALLENGE_RESPONSE = { + "auth_session": "session_login_xyz", + "authn_params_public_key": { + "challenge": "bG9naW4tY2hhbGxlbmdl", + "rpId": "auth0.local", + "timeout": 60000, + "userVerification": "preferred", + }, +} + +_PASSKEY_TOKEN_RESPONSE = { + "access_token": "at_passkey_123", + "id_token": "eyJ.test.jwt", + "token_type": "Bearer", + "expires_in": 86400, + "scope": "openid profile", +} + + +def _make_passkey_authn_response(): + from auth0_server_python.auth_types import PasskeyAuthResponse + return PasskeyAuthResponse( + id="cred_abc123", + raw_id="Y3JlZF9hYmMxMjM", + type="public-key", + authenticator_attachment="platform", + response={ + "clientDataJSON": "eyJ0eXBlIjoid2ViYXV0aG4uZ2V0In0", + "authenticatorData": "SZYN5YgOjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2M", + "signature": "MEUCIQC", + "userHandle": "dXNlcl8x", + }, + ) + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_success(mocker): + from auth0_server_python.auth_types import PasskeySignupChallengeResponse, PasskeyUserProfile + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_SIGNUP_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + result = await client.passkey_signup_challenge( + user_profile=PasskeyUserProfile(email="user@example.com", name="Jane Doe"), + connection="Username-Password-Authentication", + ) + + assert isinstance(result, PasskeySignupChallengeResponse) + assert result.auth_session == "session_abc123" + assert result.authn_params_public_key.challenge == "dGVzdC1jaGFsbGVuZ2U" + assert result.authn_params_public_key.rp.id == "auth0.local" + assert result.authn_params_public_key.user.display_name == "Jane" + assert result.authn_params_public_key.pub_key_cred_params[0].alg == -7 + assert result.authn_params_public_key.authenticator_selection.resident_key == "required" + + mock_post.assert_awaited_once() + args, kwargs = mock_post.call_args + assert "/passkey/register" in args[0] + body = kwargs["json"] + assert body["client_id"] == "test_client_id" + assert body["client_secret"] == "test_client_secret" + assert body["user_profile"]["email"] == "user@example.com" + assert body["user_profile"]["name"] == "Jane Doe" + assert body["realm"] == "Username-Password-Authentication" + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_user_profile_fields(mocker): + from auth0_server_python.auth_types import PasskeyUserProfile + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_SIGNUP_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + await client.passkey_signup_challenge( + user_profile=PasskeyUserProfile( + email="u@e.com", + username="jdoe", + phone_number="+1234567890", + given_name="Jane", + family_name="Doe", + nickname="jd", + picture="https://example.com/pic.jpg", + user_metadata={"role": "admin"}, + ), + organization="org_123", + ) + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body["user_profile"]["email"] == "u@e.com" + assert body["user_profile"]["username"] == "jdoe" + assert body["user_profile"]["phone_number"] == "+1234567890" + assert body["user_profile"]["given_name"] == "Jane" + assert body["user_profile"]["family_name"] == "Doe" + assert body["user_profile"]["nickname"] == "jd" + assert body["user_profile"]["picture"] == "https://example.com/pic.jpg" + assert body["user_profile"]["user_metadata"] == {"role": "admin"} + assert body["organization"] == "org_123" + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_minimal_body(mocker): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_SIGNUP_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + await client.passkey_signup_challenge() + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body == {"client_id": "test_client_id", "client_secret": "test_client_secret"} + assert "user_profile" not in body + assert "realm" not in body + assert "organization" not in body + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_api_error(mocker): + from auth0_server_python.auth_types import PasskeyUserProfile + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 403 + mock_response.json = MagicMock(return_value={ + "error": "access_denied", + "error_description": "Passkey not enabled", + }) + mock_post.return_value = mock_response + + with pytest.raises(PasskeyError) as exc: + await client.passkey_signup_challenge( + user_profile=PasskeyUserProfile(email="test@example.com") + ) + assert "access_denied" in str(exc.value) or "Passkey not enabled" in str(exc.value) + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_non_json_error(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 502 + mock_response.json = MagicMock(side_effect=json.JSONDecodeError("bad", "", 0)) + mock_post.return_value = mock_response + + with pytest.raises(PasskeyError) as exc: + await client.passkey_signup_challenge() + assert "502" in str(exc.value) or "passkey_challenge_error" in str(exc.value) + + +@pytest.mark.asyncio +async def test_passkey_signup_challenge_network_error(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_post.side_effect = Exception("Connection refused") + + with pytest.raises(PasskeyError) as exc: + await client.passkey_signup_challenge() + assert "Passkey signup challenge failed" in str(exc.value) + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_success(mocker): + from auth0_server_python.auth_types import PasskeyLoginChallengeResponse + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_LOGIN_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + result = await client.passkey_login_challenge( + connection="Username-Password-Authentication", + organization="org_abc", + ) + + assert isinstance(result, PasskeyLoginChallengeResponse) + assert result.auth_session == "session_login_xyz" + assert result.authn_params_public_key.challenge == "bG9naW4tY2hhbGxlbmdl" + assert result.authn_params_public_key.rp_id == "auth0.local" + assert result.authn_params_public_key.user_verification == "preferred" + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body["client_id"] == "test_client_id" + assert body["realm"] == "Username-Password-Authentication" + assert body["organization"] == "org_abc" + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_with_username(mocker): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_LOGIN_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + await client.passkey_login_challenge(username="jane@example.com") + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body["username"] == "jane@example.com" + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_api_error(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 400 + mock_response.json = MagicMock(return_value={ + "error": "invalid_request", + "error_description": "Missing client_id", + }) + mock_post.return_value = mock_response + + with pytest.raises(PasskeyError): + await client.passkey_login_challenge() + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_network_error(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_post.side_effect = Exception("timeout") + + with pytest.raises(PasskeyError): + await client.passkey_login_challenge() + + +@pytest.mark.asyncio +async def test_signin_with_passkey_success(mocker): + from auth0_server_python.auth_types import PasskeyTokenResponse + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + authn_response = _make_passkey_authn_response() + + result = await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=authn_response, + scope="openid profile", + audience="https://api.example.com", + connection="Username-Password-Authentication", + organization="org_abc", + ) + + assert isinstance(result, PasskeyTokenResponse) + assert result.access_token == "at_passkey_123" + assert result.token_type == "Bearer" + assert abs(result.expires_at - (int(time.time()) + 86400)) <= 2 + + mock_post.assert_awaited_once() + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body["grant_type"] == "urn:okta:params:oauth:grant-type:webauthn" + assert body["client_id"] == "test_client_id" + assert body["client_secret"] == "test_client_secret" + assert body["auth_session"] == "session_xyz" + assert body["scope"] == "openid profile" + assert body["audience"] == "https://api.example.com" + assert body["realm"] == "Username-Password-Authentication" + assert body["organization"] == "org_abc" + assert body["authn_response"]["rawId"] == "Y3JlZF9hYmMxMjM" + assert body["authn_response"]["authenticatorAttachment"] == "platform" + assert "raw_id" not in body["authn_response"] + + +@pytest.mark.asyncio +async def test_signin_with_passkey_uses_json_content_type(mocker): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + await client.signin_with_passkey( + auth_session="s", + authn_response=_make_passkey_authn_response(), + ) + + args, kwargs = mock_post.call_args + assert "json" in kwargs + assert "data" not in kwargs + + +@pytest.mark.asyncio +@pytest.mark.parametrize("auth_session", [None, ""]) +async def test_signin_with_passkey_missing_auth_session(auth_session): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + with pytest.raises(MissingRequiredArgumentError): + await client.signin_with_passkey( + auth_session=auth_session, + authn_response=_make_passkey_authn_response(), + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_authn_response(): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + with pytest.raises(MissingRequiredArgumentError): + await client.signin_with_passkey( + auth_session="session_abc", + authn_response=None, + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_api_error(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 401 + mock_response.json = MagicMock(return_value={ + "error": "invalid_grant", + "error_description": "Invalid auth_session", + }) + mock_post.return_value = mock_response + + with pytest.raises(PasskeyError) as exc: + await client.signin_with_passkey( + auth_session="expired_session", + authn_response=_make_passkey_authn_response(), + ) + assert "invalid_grant" in str(exc.value) or "Invalid auth_session" in str(exc.value) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_token_endpoint(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={}) + + with pytest.raises(PasskeyError) as exc: + await client.signin_with_passkey( + auth_session="session", + authn_response=_make_passkey_authn_response(), + ) + assert "token endpoint" in str(exc.value).lower() + + +@pytest.mark.asyncio +async def test_signin_with_passkey_network_error(mocker): + from auth0_server_python.error import PasskeyError + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_post.side_effect = Exception("Connection reset") + + with pytest.raises(PasskeyError): + await client.signin_with_passkey( + auth_session="session", + authn_response=_make_passkey_authn_response(), + ) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_no_client_secret(mocker): + client = ServerClient( + domain="auth0.local", + client_id="public_client", + client_secret=None, + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + from auth0_server_python.auth_types import PasskeyAuthResponse + authn_resp = PasskeyAuthResponse( + id="cred", + raw_id="cmF3", + type="public-key", + response={"clientDataJSON": "abc", "authenticatorData": "def", "signature": "ghi"}, + ) + await client.signin_with_passkey(auth_session="session", authn_response=authn_resp) + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert "client_secret" not in body + assert body["client_id"] == "public_client" + + +def test_passkey_signup_challenge_repr_redacts_auth_session(): + from auth0_server_python.auth_types import PasskeySignupChallengeResponse + resp = PasskeySignupChallengeResponse.model_validate(_PASSKEY_SIGNUP_CHALLENGE_RESPONSE) + repr_str = repr(resp) + assert "session_abc123" not in repr_str + assert "[REDACTED]" in repr_str + + +def test_passkey_login_challenge_repr_redacts_auth_session(): + from auth0_server_python.auth_types import PasskeyLoginChallengeResponse + resp = PasskeyLoginChallengeResponse.model_validate(_PASSKEY_LOGIN_CHALLENGE_RESPONSE) + repr_str = repr(resp) + assert "session_login_xyz" not in repr_str + assert "[REDACTED]" in repr_str + + +def test_passkey_token_response_repr_redacts_tokens(): + from auth0_server_python.auth_types import PasskeyTokenResponse + resp = PasskeyTokenResponse( + access_token="secret_at_value", + token_type="Bearer", + expires_in=86400, + id_token="secret_id_token", + refresh_token="secret_rt_value", + ) + repr_str = repr(resp) + assert "secret_at_value" not in repr_str + assert "secret_id_token" not in repr_str + assert "secret_rt_value" not in repr_str + assert "[REDACTED]" in repr_str + assert "86400" in repr_str + + +@pytest.mark.asyncio +async def test_signin_with_passkey_preserves_server_expires_at(mocker): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value={ + "access_token": "at_123", + "token_type": "Bearer", + "expires_in": 3600, + "expires_at": 9999999999, + }) + mock_post.return_value = mock_response + + from auth0_server_python.auth_types import PasskeyTokenResponse + result = await client.signin_with_passkey( + auth_session="session", + authn_response=_make_passkey_authn_response(), + ) + assert result.expires_at == 9999999999 + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_expires_at_calculates(mocker): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value={ + "access_token": "at_123", + "token_type": "Bearer", + "expires_in": 60, + }) + mock_post.return_value = mock_response + + result = await client.signin_with_passkey( + auth_session="session", + authn_response=_make_passkey_authn_response(), + ) + assert abs(result.expires_at - (int(time.time()) + 60)) <= 2 + + +@pytest.mark.asyncio +async def test_signin_with_passkey_dpop_attaches_proof_header(mocker): + import base64 + import json as _json + from jwcrypto import jwk as jwk_module + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + dpop_key=dpop_key, + ) + + args, kwargs = mock_post.call_args + assert "DPoP" in kwargs["headers"] + + # Decode proof and assert no ath claim (token endpoint proof — RFC 9449 §4.2) + proof = kwargs["headers"]["DPoP"] + payload_b64 = proof.split(".")[1] + padding = 4 - len(payload_b64) % 4 + payload = _json.loads(base64.urlsafe_b64decode(payload_b64 + "=" * padding)) + assert "ath" not in payload + assert "jti" in payload + assert payload["htm"] == "POST" + assert payload["htu"] == "https://auth0.local/oauth/token" + + +@pytest.mark.asyncio +async def test_signin_with_passkey_dpop_nonce_retry(mocker): + import base64 + import json as _json + from jwcrypto import jwk as jwk_module + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + + nonce_response = AsyncMock() + nonce_response.status_code = 401 + nonce_response.headers = {"DPoP-Nonce": "server-nonce-abc"} + nonce_response.json = MagicMock(return_value={"error": "use_dpop_nonce"}) + + success_response = AsyncMock() + success_response.status_code = 200 + success_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + + mock_post.side_effect = [nonce_response, success_response] + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + result = await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + dpop_key=dpop_key, + ) + + assert mock_post.await_count == 2 + assert result.access_token == "at_passkey_123" + + # Second call must include the nonce in the DPoP proof + second_call_kwargs = mock_post.call_args_list[1][1] + proof = second_call_kwargs["headers"]["DPoP"] + payload_b64 = proof.split(".")[1] + padding = 4 - len(payload_b64) % 4 + payload = _json.loads(base64.urlsafe_b64decode(payload_b64 + "=" * padding)) + assert payload["nonce"] == "server-nonce-abc" + + +@pytest.mark.asyncio +async def test_signin_with_passkey_without_dpop_no_dpop_header(mocker): + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + ) + + args, kwargs = mock_post.call_args + assert "DPoP" not in kwargs.get("headers", {}) From d299dbff859f633e6bfc4646a1b7db735f761039 Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Fri, 5 Jun 2026 13:16:43 +0530 Subject: [PATCH 8/9] SDK-8780 Added review changes for integrating passkey sign-in with SDKs state handling --- .../auth_server/server_client.py | 66 ++++++- .../auth_types/__init__.py | 12 ++ .../tests/test_server_client.py | 166 ++++++++++++++++-- 3 files changed, 218 insertions(+), 26 deletions(-) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 110f33a..1803ea2 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -37,6 +37,7 @@ MfaRequirements, PasskeyAuthResponse, PasskeyLoginChallengeResponse, + PasskeyLoginResult, PasskeySignupChallengeResponse, PasskeyUserProfile, PasskeyTokenResponse, @@ -2653,20 +2654,22 @@ async def signin_with_passkey( scope: Optional[str] = None, audience: Optional[str] = None, dpop_key: Optional["jwk.JWK"] = None, - ) -> PasskeyTokenResponse: + ) -> PasskeyLoginResult: """ Completes passkey authentication by exchanging the WebAuthn assertion - for tokens (POST /oauth/token with webauthn grant). + for tokens and establishing a server-side session. This is step 2 of 2: call passkey_signup_challenge or passkey_login_challenge first to obtain auth_session and the WebAuthn challenge options. Uses Content-Type: application/json (required for nested authn_response). + Persists the session to the state store (same as complete_interactive_login). Args: auth_session: Session credential from passkey_signup_challenge or passkey_login_challenge. authn_response: Serialized WebAuthn credential from navigator.credentials.create/get. - store_options: Optional options for domain resolution and state store. + store_options: Options passed to the state store (e.g., request/response for cookies). + When None, session storage is skipped (stateless deployments). connection: Auth0 database connection name (realm). organization: Auth0 organization ID or name. scope: OAuth2 scope string. @@ -2676,11 +2679,12 @@ async def signin_with_passkey( (token_type: DPoP). Required when the tenant mandates DPoP binding. Returns: - PasskeyTokenResponse containing access_token, id_token, expires_in, etc. + PasskeyLoginResult containing state_data with user claims and token sets, + consistent with complete_interactive_login and login_with_custom_token_exchange. Raises: MissingRequiredArgumentError: If auth_session or authn_response is missing. - PasskeyError: If token exchange fails. + PasskeyError: If token exchange or session creation fails. """ if not auth_session: raise MissingRequiredArgumentError("auth_session") @@ -2755,9 +2759,57 @@ async def signin_with_passkey( if "expires_in" in token_data and "expires_at" not in token_data: token_data["expires_at"] = int(time.time()) + token_data["expires_in"] - return PasskeyTokenResponse.model_validate(token_data) + token_response = PasskeyTokenResponse.model_validate(token_data) + + # Extract user claims from ID token if present + user_claims = None + sid = PKCE.generate_random_string(32) + if token_response.id_token: + jwks = await self._get_jwks_cached(domain, metadata) + try: + claims = await self._verify_and_decode_jwt( + token_response.id_token, jwks, audience=self._client_id + ) + origin_issuer = metadata.get("issuer") + token_issuer = claims.get("iss", "") + if self._normalize_url(token_issuer) != self._normalize_url(origin_issuer): + raise IssuerValidationError( + "ID token issuer mismatch. Ensure your Auth0 domain is configured correctly." + ) + user_claims = UserClaims.parse_obj(claims) + sid = claims.get("sid", sid) + except ValueError as e: + raise ApiError("jwks_key_not_found", str(e)) + except jwt.InvalidSignatureError as e: + raise ApiError("invalid_signature", f"ID token signature verification failed: {str(e)}", e) + except jwt.InvalidAudienceError as e: + raise ApiError("invalid_audience", f"ID token audience mismatch: {str(e)}", e) + except jwt.ExpiredSignatureError as e: + raise ApiError("token_expired", f"ID token has expired: {str(e)}", e) + except jwt.InvalidTokenError as e: + raise ApiError("invalid_token", f"ID token verification failed: {str(e)}", e) + + # Build token set and session state + token_set = TokenSet( + audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, + access_token=token_response.access_token, + scope=token_response.scope or scope or "", + expires_at=token_response.expires_at or int(time.time()) + token_response.expires_in, + ) + state_data = StateData( + user=user_claims, + id_token=token_response.id_token, + refresh_token=token_response.refresh_token, + token_sets=[token_set], + domain=domain, + internal={"sid": sid, "created_at": int(time.time())}, + ) + + await self._state_store.set(self._state_identifier, state_data, options=store_options) + + return PasskeyLoginResult(state_data=state_data.dict()) except Exception as e: - if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError)): + if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError, ApiError, IssuerValidationError)): raise raise PasskeyError(PasskeyErrorCode.TOKEN_EXCHANGE_FAILED, "Passkey sign-in failed", e) from e diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 9494a22..c6caf8c 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -380,6 +380,18 @@ class LoginWithCustomTokenExchangeResult(BaseModel): authorization_details: Optional[list[AuthorizationDetails]] = None +class PasskeyLoginResult(BaseModel): + """ + Result from signin_with_passkey. + + Contains the session data established after the webauthn token exchange. + Mirrors LoginWithCustomTokenExchangeResult — passkey sign-in is a complete + login ceremony and creates a server-side session like every other login path. + """ + + state_data: dict[str, Any] + + # ============================================================================= # Connected Accounts Types # ============================================================================= diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index f987567..88d565a 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -5154,21 +5154,25 @@ async def test_passkey_login_challenge_network_error(mocker): @pytest.mark.asyncio async def test_signin_with_passkey_success(mocker): - from auth0_server_python.auth_types import PasskeyTokenResponse - from auth0_server_python.error import PasskeyError + from auth0_server_python.auth_types import PasskeyLoginResult + state_store = AsyncMock() client = ServerClient( domain="auth0.local", client_id="test_client_id", client_secret="test_client_secret", - state_store=AsyncMock(), + state_store=state_store, transaction_store=AsyncMock(), secret="test-secret-value", ) mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "name": "Jane", "iss": "https://auth0.local/", "sid": "sid_abc" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5185,10 +5189,13 @@ async def test_signin_with_passkey_success(mocker): organization="org_abc", ) - assert isinstance(result, PasskeyTokenResponse) - assert result.access_token == "at_passkey_123" - assert result.token_type == "Bearer" - assert abs(result.expires_at - (int(time.time()) + 86400)) <= 2 + assert isinstance(result, PasskeyLoginResult) + assert "token_sets" in result.state_data + assert result.state_data["token_sets"][0]["access_token"] == "at_passkey_123" + assert result.state_data["token_sets"][0]["audience"] == "https://api.example.com" + + # Session must be persisted + state_store.set.assert_awaited_once() mock_post.assert_awaited_once() args, kwargs = mock_post.call_args @@ -5219,8 +5226,12 @@ async def test_signin_with_passkey_uses_json_content_type(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5365,8 +5376,12 @@ async def test_signin_with_passkey_no_client_secret(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5434,8 +5449,12 @@ async def test_signin_with_passkey_preserves_server_expires_at(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5447,12 +5466,11 @@ async def test_signin_with_passkey_preserves_server_expires_at(mocker): }) mock_post.return_value = mock_response - from auth0_server_python.auth_types import PasskeyTokenResponse result = await client.signin_with_passkey( auth_session="session", authn_response=_make_passkey_authn_response(), ) - assert result.expires_at == 9999999999 + assert result.state_data["token_sets"][0]["expires_at"] == 9999999999 @pytest.mark.asyncio @@ -5468,8 +5486,12 @@ async def test_signin_with_passkey_missing_expires_at_calculates(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5484,7 +5506,7 @@ async def test_signin_with_passkey_missing_expires_at_calculates(mocker): auth_session="session", authn_response=_make_passkey_authn_response(), ) - assert abs(result.expires_at - (int(time.time()) + 60)) <= 2 + assert abs(result.state_data["token_sets"][0]["expires_at"] - (int(time.time()) + 60)) <= 2 @pytest.mark.asyncio @@ -5503,8 +5525,12 @@ async def test_signin_with_passkey_dpop_attaches_proof_header(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5548,8 +5574,12 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) nonce_response = AsyncMock() @@ -5571,7 +5601,7 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): ) assert mock_post.await_count == 2 - assert result.access_token == "at_passkey_123" + assert result.state_data["token_sets"][0]["access_token"] == "at_passkey_123" # Second call must include the nonce in the DPoP proof second_call_kwargs = mock_post.call_args_list[1][1] @@ -5595,8 +5625,12 @@ async def test_signin_with_passkey_without_dpop_no_dpop_header(mocker): mocker.patch.object( client, "_get_oidc_metadata_cached", - return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 @@ -5610,3 +5644,97 @@ async def test_signin_with_passkey_without_dpop_no_dpop_header(mocker): args, kwargs = mock_post.call_args assert "DPoP" not in kwargs.get("headers", {}) + + +@pytest.mark.asyncio +async def test_signin_with_passkey_creates_session_in_state_store(mocker): + """signin_with_passkey must persist a session — consistent with complete_interactive_login.""" + from auth0_server_python.auth_types import PasskeyLoginResult + state_store = AsyncMock() + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=state_store, + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, + ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", + "name": "Jane Doe", + "email": "jane@example.com", + "iss": "https://auth0.local/", + "sid": "session_sid_abc", + }) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + result = await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + ) + + # State store must be called exactly once + state_store.set.assert_awaited_once() + + # Result must be PasskeyLoginResult, not bare tokens + assert isinstance(result, PasskeyLoginResult) + + # State data must contain user, token_sets, domain, internal + sd = result.state_data + assert sd["user"]["sub"] == "auth0|user123" + assert sd["user"]["name"] == "Jane Doe" + assert sd["token_sets"][0]["access_token"] == "at_passkey_123" + assert sd["id_token"] == "eyJ.test.jwt" + assert sd["refresh_token"] is None + assert sd["domain"] == "auth0.local" + assert sd["internal"]["sid"] == "session_sid_abc" + assert "created_at" in sd["internal"] + + +@pytest.mark.asyncio +async def test_signin_with_passkey_session_without_id_token(mocker): + """When no id_token is returned, session is still created with user=None.""" + from auth0_server_python.auth_types import PasskeyLoginResult + state_store = AsyncMock() + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=state_store, + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value={ + "access_token": "at_no_id_token", + "token_type": "Bearer", + "expires_in": 3600, + }) + mock_post.return_value = mock_response + + result = await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + ) + + assert isinstance(result, PasskeyLoginResult) + state_store.set.assert_awaited_once() + assert result.state_data["user"] is None + assert result.state_data["token_sets"][0]["access_token"] == "at_no_id_token" From 0aacc8ebd0590121d5b6be707eb92216b571321e Mon Sep 17 00:00:00 2001 From: Sourav Basu Date: Fri, 5 Jun 2026 22:29:02 +0530 Subject: [PATCH 9/9] PR Review Changes --- .../auth_schemes/dpop_auth.py | 3 +- .../auth_server/my_account_client.py | 19 +-- .../auth_server/server_client.py | 28 +++- .../auth_types/__init__.py | 24 ++- .../tests/test_my_account_client.py | 73 +++++++++ .../tests/test_server_client.py | 143 +++++++++++++++++- 6 files changed, 254 insertions(+), 36 deletions(-) diff --git a/src/auth0_server_python/auth_schemes/dpop_auth.py b/src/auth0_server_python/auth_schemes/dpop_auth.py index 0bf2d66..a0e0a19 100644 --- a/src/auth0_server_python/auth_schemes/dpop_auth.py +++ b/src/auth0_server_python/auth_schemes/dpop_auth.py @@ -61,8 +61,7 @@ def auth_flow(self, request: httpx.Request): # RFC 9449 §8.2 — server-nonce retry if ( - response is not None - and response.status_code == 401 + response.status_code == 401 and response.headers.get("DPoP-Nonce") ): nonce = response.headers["DPoP-Nonce"] diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index 5ffadd9..5e10b60 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -1,8 +1,9 @@ import json from typing import TYPE_CHECKING, Optional -from urllib.parse import quote, unquote +from urllib.parse import quote, unquote, urlparse import httpx + from auth0_server_python.auth_schemes.bearer_auth import BearerAuth from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth from auth0_server_python.auth_types import ( @@ -654,12 +655,12 @@ async def enroll_authentication_method( if not location: raise ApiError( "enroll_authentication_method_error", - "Enrollment succeeded (201) but Location header is missing", + "Enrollment succeeded (202) but Location header is missing", ) - path = location.split("?")[0].split("#")[0].rstrip("/") - segments = path.split("/") - authentication_method_id = unquote(segments[-1]) if len(segments) > 1 else "" + parsed_path = urlparse(location).path.rstrip("/") + raw_id = parsed_path.rsplit("/", 1)[-1] if "/" in parsed_path else "" + authentication_method_id = unquote(raw_id) if not authentication_method_id or authentication_method_id in ( "authentication-methods", "v1", @@ -667,7 +668,7 @@ async def enroll_authentication_method( ): raise ApiError( "enroll_authentication_method_error", - "Enrollment succeeded (201) but could not extract ID from Location header", + "Enrollment succeeded (202) but could not extract ID from Location header", ) try: @@ -675,21 +676,21 @@ async def enroll_authentication_method( except (json.JSONDecodeError, ValueError): raise ApiError( "enroll_authentication_method_error", - "Enrollment succeeded (201) but response body is not valid JSON", + "Enrollment succeeded (202) but response body is not valid JSON", ) auth_session = data.get("auth_session") if not auth_session: raise ApiError( "enroll_authentication_method_error", - "Enrollment succeeded (201) but auth_session is missing from response", + "Enrollment succeeded (202) but auth_session is missing from response", ) return EnrollmentChallengeResponse.model_validate( { + **data, "authentication_method_id": authentication_method_id, "auth_session": auth_session, - "authn_params_public_key": data.get("authn_params_public_key"), } ) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 1803ea2..22f6d21 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -16,10 +16,10 @@ import httpx import jwt from authlib.integrations.base_client.errors import OAuthError -from auth0_server_python.auth_schemes.dpop_auth import DPoPAuth, make_dpop_proof_for_token_endpoint from authlib.integrations.httpx_client import AsyncOAuth2Client from pydantic import ValidationError +from auth0_server_python.auth_schemes.dpop_auth import make_dpop_proof_for_token_endpoint from auth0_server_python.auth_server.mfa_client import MfaClient from auth0_server_python.auth_server.my_account_client import MyAccountClient from auth0_server_python.auth_types import ( @@ -39,8 +39,8 @@ PasskeyLoginChallengeResponse, PasskeyLoginResult, PasskeySignupChallengeResponse, - PasskeyUserProfile, PasskeyTokenResponse, + PasskeyUserProfile, StartInteractiveLoginOptions, StateData, TokenExchangeResponse, @@ -2508,6 +2508,7 @@ async def passkey_signup_challenge( user_profile: Optional[PasskeyUserProfile] = None, connection: Optional[str] = None, organization: Optional[str] = None, + user_metadata: Optional[dict[str, Any]] = None, store_options: Optional[dict[str, Any]] = None, ) -> PasskeySignupChallengeResponse: """ @@ -2521,6 +2522,8 @@ async def passkey_signup_challenge( Use PasskeyUserProfile — supports extra fields for forward compatibility. connection: Auth0 database connection name (realm). organization: Auth0 organization ID or name. + user_metadata: Optional custom metadata added at the root of the request body, + not nested inside user_profile (per Auth0 API spec). store_options: Optional options for domain resolution. Returns: @@ -2537,6 +2540,8 @@ async def passkey_signup_challenge( body["client_secret"] = self._client_secret if user_profile: body["user_profile"] = user_profile.model_dump(exclude_none=True) + if user_metadata: + body["user_metadata"] = user_metadata if connection: body["realm"] = connection if organization: @@ -2669,7 +2674,7 @@ async def signin_with_passkey( auth_session: Session credential from passkey_signup_challenge or passkey_login_challenge. authn_response: Serialized WebAuthn credential from navigator.credentials.create/get. store_options: Options passed to the state store (e.g., request/response for cookies). - When None, session storage is skipped (stateless deployments). + Passed through to the store on every call. connection: Auth0 database connection name (realm). organization: Auth0 organization ID or name. scope: OAuth2 scope string. @@ -2761,6 +2766,13 @@ async def signin_with_passkey( token_response = PasskeyTokenResponse.model_validate(token_data) + if dpop_key is not None and token_response.token_type.lower() != "dpop": + raise PasskeyError( + PasskeyErrorCode.TOKEN_EXCHANGE_FAILED, + f"DPoP token binding failed: expected token_type 'DPoP', " + f"got '{token_response.token_type}'", + ) + # Extract user claims from ID token if present user_claims = None sid = PKCE.generate_random_string(32) @@ -2771,12 +2783,16 @@ async def signin_with_passkey( token_response.id_token, jwks, audience=self._client_id ) origin_issuer = metadata.get("issuer") + if not origin_issuer: + raise IssuerValidationError( + "Issuer missing from OIDC metadata. Cannot validate ID token issuer." + ) token_issuer = claims.get("iss", "") if self._normalize_url(token_issuer) != self._normalize_url(origin_issuer): raise IssuerValidationError( "ID token issuer mismatch. Ensure your Auth0 domain is configured correctly." ) - user_claims = UserClaims.parse_obj(claims) + user_claims = UserClaims.model_validate(claims) sid = claims.get("sid", sid) except ValueError as e: raise ApiError("jwks_key_not_found", str(e)) @@ -2794,7 +2810,7 @@ async def signin_with_passkey( audience=audience or self.DEFAULT_AUDIENCE_STATE_KEY, access_token=token_response.access_token, scope=token_response.scope or scope or "", - expires_at=token_response.expires_at or int(time.time()) + token_response.expires_in, + expires_at=token_response.expires_at if token_response.expires_at is not None else int(time.time()) + token_response.expires_in, ) state_data = StateData( user=user_claims, @@ -2807,7 +2823,7 @@ async def signin_with_passkey( await self._state_store.set(self._state_identifier, state_data, options=store_options) - return PasskeyLoginResult(state_data=state_data.dict()) + return PasskeyLoginResult(state_data=state_data.model_dump()) except Exception as e: if isinstance(e, (PasskeyError, MissingRequiredArgumentError, ValidationError, ApiError, IssuerValidationError)): diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index c6caf8c..44ec918 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -688,7 +688,7 @@ class PasskeyAuthenticatorSelection(BaseModel): class PasskeyPublicKeyOptions(BaseModel): - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, extra="allow") challenge: str rp: Optional[PasskeyRpInfo] = None rp_id: Optional[str] = Field(None, alias="rpId") @@ -717,6 +717,7 @@ class EnrollAuthenticationMethodRequest(BaseModel): class EnrollmentChallengeResponse(BaseModel): + model_config = ConfigDict(extra="allow") authentication_method_id: str auth_session: str authn_params_public_key: Optional[PasskeyPublicKeyOptions] = None @@ -737,7 +738,7 @@ class PasskeyAuthResponse(BaseModel): type: str authenticator_attachment: Optional[str] = Field(None, alias="authenticatorAttachment") response: dict[str, str] - client_extension_results: Optional[dict] = Field(None, alias="clientExtensionResults") + client_extension_results: Optional[dict[str, Any]] = Field(None, alias="clientExtensionResults") class VerifyAuthenticationMethodRequest(BaseModel): @@ -803,31 +804,26 @@ class PasskeyUserProfile(BaseModel): family_name: Optional[str] = None nickname: Optional[str] = None picture: Optional[str] = None - user_metadata: Optional[dict[str, Any]] = None -class PasskeySignupChallengeResponse(BaseModel): +class _PasskeyChallengeResponseBase(BaseModel): auth_session: str authn_params_public_key: PasskeyPublicKeyOptions def __repr__(self) -> str: return ( - f"PasskeySignupChallengeResponse(" + f"{self.__class__.__name__}(" f"auth_session=[REDACTED], " f"authn_params_public_key={self.authn_params_public_key!r})" ) -class PasskeyLoginChallengeResponse(BaseModel): - auth_session: str - authn_params_public_key: PasskeyPublicKeyOptions +class PasskeySignupChallengeResponse(_PasskeyChallengeResponseBase): + pass - def __repr__(self) -> str: - return ( - f"PasskeyLoginChallengeResponse(" - f"auth_session=[REDACTED], " - f"authn_params_public_key={self.authn_params_public_key!r})" - ) + +class PasskeyLoginChallengeResponse(_PasskeyChallengeResponseBase): + pass class PasskeyTokenResponse(BaseModel): diff --git a/src/auth0_server_python/tests/test_my_account_client.py b/src/auth0_server_python/tests/test_my_account_client.py index da2875d..6f254c1 100644 --- a/src/auth0_server_python/tests/test_my_account_client.py +++ b/src/auth0_server_python/tests/test_my_account_client.py @@ -804,6 +804,36 @@ async def test_enroll_authentication_method_success(mocker): assert result.authn_params_public_key.user.display_name == "Test User" +@pytest.mark.asyncio +async def test_enroll_authentication_method_public_key_extra_fields_preserved(mocker): + """Unknown WebAuthn fields (excludeCredentials, attestation, extensions) must not be dropped.""" + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/passkey|new"} + response.json = MagicMock(return_value={ + "auth_session": "session_abc", + "authn_params_public_key": { + "challenge": "dGVzdA", + "rp": {"id": "auth0.local", "name": "My App"}, + "user": {"id": "dXNlcl8x", "name": "user@test.com"}, + "pubKeyCredParams": [{"type": "public-key", "alg": -7}], + "excludeCredentials": [{"type": "public-key", "id": "Y3JlZA"}], + "attestation": "direct", + "extensions": {"appid": "https://auth0.local"}, + }, + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="passkey") + result = await client.enroll_authentication_method(access_token="token123", request=req) + + pk = result.authn_params_public_key + assert pk.model_extra["excludeCredentials"] == [{"type": "public-key", "id": "Y3JlZA"}] + assert pk.model_extra["attestation"] == "direct" + assert pk.model_extra["extensions"] == {"appid": "https://auth0.local"} + + @pytest.mark.asyncio async def test_enroll_authentication_method_missing_location(mocker): client = MyAccountClient(domain="auth0.local") @@ -848,6 +878,49 @@ async def test_enroll_authentication_method_location_absolute_url(mocker): assert result.authentication_method_id == "am_xyz" +@pytest.mark.asyncio +async def test_enroll_authentication_method_totp_preserves_secret(mocker): + """TOTP enrollment response includes totp_secret and barcode_uri — must not be dropped.""" + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/totp|new"} + response.json = MagicMock(return_value={ + "auth_session": "session_totp", + "totp_secret": "JBSWY3DPEHPK3PXP", + "barcode_uri": "otpauth://totp/Example:alice@example.com?secret=JBSWY3DPEHPK3PXP", + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="totp") + result = await client.enroll_authentication_method(access_token="token123", request=req) + + assert result.authentication_method_id == "totp|new" + assert result.auth_session == "session_totp" + assert result.model_extra["totp_secret"] == "JBSWY3DPEHPK3PXP" + assert result.model_extra["barcode_uri"].startswith("otpauth://") + + +@pytest.mark.asyncio +async def test_enroll_authentication_method_oob_preserves_oob_code(mocker): + """OOB (email/phone) enrollment response includes oob_code — must not be dropped.""" + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 202 + response.headers = {"location": "/me/v1/authentication-methods/email|new"} + response.json = MagicMock(return_value={ + "auth_session": "session_oob", + "oob_code": "oob_abc123", + }) + mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + + req = EnrollAuthenticationMethodRequest(type="email") + result = await client.enroll_authentication_method(access_token="token123", request=req) + + assert result.authentication_method_id == "email|new" + assert result.model_extra["oob_code"] == "oob_abc123" + + @pytest.mark.asyncio async def test_verify_authentication_method_success(mocker): client = MyAccountClient(domain="auth0.local") diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 88d565a..3c271ea 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -4855,6 +4855,14 @@ async def _fake_fetch(self, domain): "scope": "openid profile", } +_PASSKEY_TOKEN_RESPONSE_DPOP = { + "access_token": "at_passkey_dpop_123", + "id_token": "eyJ.test.jwt", + "token_type": "DPoP", + "expires_in": 86400, + "scope": "openid profile", +} + def _make_passkey_authn_response(): from auth0_server_python.auth_types import PasskeyAuthResponse @@ -4939,8 +4947,8 @@ async def test_passkey_signup_challenge_user_profile_fields(mocker): family_name="Doe", nickname="jd", picture="https://example.com/pic.jpg", - user_metadata={"role": "admin"}, ), + user_metadata={"role": "admin"}, organization="org_123", ) @@ -4953,7 +4961,8 @@ async def test_passkey_signup_challenge_user_profile_fields(mocker): assert body["user_profile"]["family_name"] == "Doe" assert body["user_profile"]["nickname"] == "jd" assert body["user_profile"]["picture"] == "https://example.com/pic.jpg" - assert body["user_profile"]["user_metadata"] == {"role": "admin"} + assert "user_metadata" not in body["user_profile"] + assert body["user_metadata"] == {"role": "admin"} assert body["organization"] == "org_123" @@ -4979,10 +4988,38 @@ async def test_passkey_signup_challenge_minimal_body(mocker): body = kwargs["json"] assert body == {"client_id": "test_client_id", "client_secret": "test_client_secret"} assert "user_profile" not in body + assert "user_metadata" not in body assert "realm" not in body assert "organization" not in body +@pytest.mark.asyncio +async def test_passkey_signup_challenge_user_metadata_root_level(mocker): + """user_metadata must be sent at root level, not nested inside user_profile.""" + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_SIGNUP_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + await client.passkey_signup_challenge( + user_metadata={"preferred_language": "en"}, + ) + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body["user_metadata"] == {"preferred_language": "en"} + assert "user_profile" not in body + + @pytest.mark.asyncio async def test_passkey_signup_challenge_api_error(mocker): from auth0_server_python.auth_types import PasskeyUserProfile @@ -5085,6 +5122,34 @@ async def test_passkey_login_challenge_success(mocker): assert body["client_id"] == "test_client_id" assert body["realm"] == "Username-Password-Authentication" assert body["organization"] == "org_abc" + assert "username" not in body + + +@pytest.mark.asyncio +async def test_passkey_login_challenge_minimal_body(mocker): + """No optional fields sent when called with no arguments.""" + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_LOGIN_CHALLENGE_RESPONSE) + mock_post.return_value = mock_response + + await client.passkey_login_challenge() + + args, kwargs = mock_post.call_args + body = kwargs["json"] + assert body == {"client_id": "test_client_id", "client_secret": "test_client_secret"} + assert "username" not in body + assert "realm" not in body + assert "organization" not in body @pytest.mark.asyncio @@ -5534,7 +5599,7 @@ async def test_signin_with_passkey_dpop_attaches_proof_header(mocker): mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() mock_response.status_code = 200 - mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE_DPOP) mock_post.return_value = mock_response dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") @@ -5589,7 +5654,7 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): success_response = AsyncMock() success_response.status_code = 200 - success_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + success_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE_DPOP) mock_post.side_effect = [nonce_response, success_response] @@ -5601,7 +5666,7 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): ) assert mock_post.await_count == 2 - assert result.state_data["token_sets"][0]["access_token"] == "at_passkey_123" + assert result.state_data["token_sets"][0]["access_token"] == "at_passkey_dpop_123" # Second call must include the nonce in the DPoP proof second_call_kwargs = mock_post.call_args_list[1][1] @@ -5612,6 +5677,74 @@ async def test_signin_with_passkey_dpop_nonce_retry(mocker): assert payload["nonce"] == "server-nonce-abc" +@pytest.mark.asyncio +async def test_signin_with_passkey_dpop_rejects_bearer_downgrade(mocker): + """Server returning token_type=Bearer when DPoP was requested must raise PasskeyError.""" + from auth0_server_python.error import PasskeyError + from jwcrypto import jwk as jwk_module + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token", "issuer": "https://auth0.local/"}, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + dpop_key = jwk_module.JWK.generate(kty="EC", crv="P-256") + with pytest.raises(PasskeyError) as exc: + await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + dpop_key=dpop_key, + ) + assert "DPoP" in str(exc.value) or "token_type" in str(exc.value).lower() + + +@pytest.mark.asyncio +async def test_signin_with_passkey_missing_issuer_in_metadata(mocker): + """Missing 'issuer' in OIDC metadata must raise IssuerValidationError, not silently pass.""" + client = ServerClient( + domain="auth0.local", + client_id="test_client_id", + client_secret="test_client_secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + secret="test-secret-value", + ) + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/oauth/token"}, + ) + mocker.patch.object(client, "_get_jwks_cached", return_value={}) + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "sub": "auth0|user123", "iss": "https://auth0.local/" + }) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=_PASSKEY_TOKEN_RESPONSE) + mock_post.return_value = mock_response + + with pytest.raises(Exception) as exc: + await client.signin_with_passkey( + auth_session="session_xyz", + authn_response=_make_passkey_authn_response(), + ) + assert "issuer" in str(exc.value).lower() + + @pytest.mark.asyncio async def test_signin_with_passkey_without_dpop_no_dpop_header(mocker): client = ServerClient(