diff --git a/examples/ConnectedAccounts.md b/examples/ConnectedAccounts.md new file mode 100644 index 0000000..b8f9135 --- /dev/null +++ b/examples/ConnectedAccounts.md @@ -0,0 +1,107 @@ +# Connect Accounts for using Token Vault + +The Connect Accounts feature uses the Auth0 My Account API to allow users to link multiple third party accounts to a single Auth0 user profile. In order to use this feature, [My Account API](https://auth0.com/docs/manage-users/my-account-api) must be activated on your Auth0 tenant. + +>[!NOTE] +>DPoP sender token constraining is not yet supported in this SDK. My Account API can be configured to support it (default behaviour) but must not be configured to require it. + + +When using Connected Accounts, Auth0 acquires tokens from upstream Identity Providers (like Google) and stores them in a secure [Token Vault](https://auth0.com/docs/secure/tokens/token-vault). These tokens can then be used to access third-party APIs (like Google Calendar) on behalf of the user. + +The tokens in the Token Vault are then accessible to [Applications](https://auth0.com/docs/get-started/applications) configured in Auth0. The application can issue requests to Auth0 to retrieve the tokens from the Token Vault and use them to access the third-party APIs. + +This is particularly useful for applications that require access to different resources on behalf of a user, like AI Agents. + +## Configure the SDK + +The Auth0 client Application must be configured to use refresh tokens and [MRRT (Multiple Resource Refresh Tokens)](https://auth0.com/docs/secure/tokens/refresh-tokens/multi-resource-refresh-token) since we will use the refresh token grant to get Access Tokens for the My Account API in addition to the API we are calling. + +```python +server_client = ServerClient( + domain="YOUR_AUTH0_DOMAIN", + client_id="YOUR_CLIENT_ID", + client_secret="YOUR_CLIENT_SECRET", + secret="YOUR_SECRET", + authorization_params={ + "redirect_uri":"YOUR_CALLBACK_URL", + "audience": "YOUR_API_IDENTIFIER" + } +) +``` + +## Login to the application + +Use the login methods to authenticate to the application and get a refresh and access token for the API. + +```python +# Login specifying any scopes for the Auth0 API + +authorization_url = await server_client.start_interactive_login( + { + "authorization_params": { + # must include offline_access to obtain a refresh token + "scope": "openid profile email offline_access" + } + }, + store_options={"request": request, "response": response} +) + +# redirect user + +# handle redirect +result = await server_client.complete_interactive_login( + callback_url, + store_options={"request": request, "response": response} +) +``` + +## Connect to a third-party account + +Start the flow using the `start_connect_account` method to redirect the user to the third-party Identity Provider to connect their account. + +The `authorization_params` is used to pass additional parameters required by the third-party IdP +The `app_state` parameter allows you to pass custom state (for example, a return URL) that is later available when the connect process completes. + +```python + +connect_url = await self.client.start_connect_account( + ConnectAccountOptions( + connection="CONNECTION", # e.g. google-oauth2 + redirect_uri="YOUR_CALLBACK_URL" + app_state= { + "returnUrl":"SOME_URL" + } + scopes= [ + # scopes to passed to the third-party IdP + "openid", + "email", + "profile" + "offline_access" + ] + authorization_params= { + # additional auth parameters to be sent to the third-party IdP e.g. + "login_hint": "user123", + "resource": "some_resource" + } + ), + store_options={"request": request, "response": response} +) +``` + +Using the url returned, redirect the user to the third-party Identity Provider to complete any required authorization. Once authorized, the user will be redirected back to the provided `redirect_uri` with a `connect_code` and `state` parameter. + +## Complete the account connection + +Call the `complete_connect_account` method using the full callback url returned from the third-party IdP to complete the connected account flow. This method extracts the connect_code from the URL, completes the connection, and returns the response data (including any `app_state` you passed originally). + +```python +complete_response = await self.client.complete_connect_account( + url= callback_url, + store_options=store_options +) +``` + +>[!NOTE] +>The `callback_url` must include the necessary parameters (`state` and `connect_code`) that Auth0 sends upon successful authentication. + +You can now call the API with your access token and the API can use [Access Token Exchange with Token Vault](https://auth0.com/docs/secure/tokens/token-vault/access-token-exchange-with-token-vault) to get tokens from the Token Vault to access third-party APIs on behalf of the user. \ No newline at end of file diff --git a/src/auth0_server_python/auth_schemes/__init__.py b/src/auth0_server_python/auth_schemes/__init__.py new file mode 100644 index 0000000..1c2c869 --- /dev/null +++ b/src/auth0_server_python/auth_schemes/__init__.py @@ -0,0 +1,3 @@ +from .bearer_auth import BearerAuth + +__all__ = ["BearerAuth"] diff --git a/src/auth0_server_python/auth_schemes/bearer_auth.py b/src/auth0_server_python/auth_schemes/bearer_auth.py new file mode 100644 index 0000000..8fd629e --- /dev/null +++ b/src/auth0_server_python/auth_schemes/bearer_auth.py @@ -0,0 +1,10 @@ +import httpx + + +class BearerAuth(httpx.Auth): + def __init__(self, token: str): + self.token = token + + def auth_flow(self, request): + request.headers['Authorization'] = f"Bearer {self.token}" + yield request diff --git a/src/auth0_server_python/auth_server/__init__.py b/src/auth0_server_python/auth_server/__init__.py index b95c7c0..72818be 100644 --- a/src/auth0_server_python/auth_server/__init__.py +++ b/src/auth0_server_python/auth_server/__init__.py @@ -1,3 +1,4 @@ +from .my_account_client import MyAccountClient from .server_client import ServerClient -__all__ = ["ServerClient"] +__all__ = ["ServerClient", "MyAccountClient"] diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py new file mode 100644 index 0000000..a5a31d2 --- /dev/null +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -0,0 +1,94 @@ + +import httpx +from auth0_server_python.auth_schemes.bearer_auth import BearerAuth +from auth0_server_python.auth_types import ( + CompleteConnectAccountRequest, + CompleteConnectAccountResponse, + ConnectAccountRequest, + ConnectAccountResponse, +) +from auth0_server_python.error import ( + ApiError, + MyAccountApiError, +) + + +class MyAccountClient: + def __init__(self, domain: str): + self._domain = domain + + @property + def audience(self): + return f"https://{self._domain}/me/" + + async def connect_account( + self, + access_token: str, + request: ConnectAccountRequest + ) -> ConnectAccountResponse: + try: + async with httpx.AsyncClient() as client: + response = await client.post( + url=f"{self.audience}v1/connected-accounts/connect", + json=request.model_dump(exclude_none=True), + auth=BearerAuth(access_token) + ) + + if response.status_code != 201: + error_data = response.json() + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None) + ) + + data = response.json() + + return ConnectAccountResponse.model_validate(data) + + except Exception as e: + if isinstance(e, MyAccountApiError): + raise + raise ApiError( + "connect_account_error", + f"Connected Accounts connect request failed: {str(e) or 'Unknown error'}", + e + ) + + async def complete_connect_account( + self, + access_token: str, + request: CompleteConnectAccountRequest + ) -> CompleteConnectAccountResponse: + try: + async with httpx.AsyncClient() as client: + response = await client.post( + url=f"{self.audience}v1/connected-accounts/complete", + json=request.model_dump(exclude_none=True), + auth=BearerAuth(access_token) + ) + + if response.status_code != 201: + error_data = response.json() + raise MyAccountApiError( + title=error_data.get("title", None), + type=error_data.get("type", None), + detail=error_data.get("detail", None), + status=error_data.get("status", None), + validation_errors=error_data.get("validation_errors", None) + ) + + data = response.json() + + return CompleteConnectAccountResponse.model_validate(data) + + except Exception as e: + if isinstance(e, MyAccountApiError): + raise + raise ApiError( + "connect_account_error", + f"Connected Accounts complete request failed: {str(e) or 'Unknown error'}", + e + ) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index e66680f..c968120 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -7,11 +7,16 @@ import json import time from typing import Any, Generic, Optional, TypeVar -from urllib.parse import parse_qs, urlparse +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse import httpx import jwt +from auth0_server_python.auth_server.my_account_client import MyAccountClient from auth0_server_python.auth_types import ( + CompleteConnectAccountRequest, + CompleteConnectAccountResponse, + ConnectAccountOptions, + ConnectAccountRequest, LogoutOptions, LogoutTokenClaims, StartInteractiveLoginOptions, @@ -78,7 +83,6 @@ def __init__( transaction_identifier: Identifier for transaction data state_identifier: Identifier for state data authorization_params: Default parameters for authorization requests - pushed_authorization_requests: Whether to use PAR for authorization requests """ if not secret: raise MissingRequiredArgumentError("secret") @@ -103,6 +107,8 @@ def __init__( client_secret=client_secret, ) + self._my_account_client = MyAccountClient(domain=domain) + async def _fetch_oidc_metadata(self, domain: str) -> dict: metadata_url = f"https://{domain}/.well-known/openid-configuration" async with httpx.AsyncClient() as client: @@ -1333,3 +1339,135 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A "There was an error while trying to retrieve an access token for a connection.", e ) + + async def start_connect_account( + self, + options: ConnectAccountOptions, + store_options: dict = None + ) -> str: + """ + Initiates the connect account flow for linking a third-party account to the user's profile. + + This method generates PKCE parameters, creates a transaction and calls the My Account API + to create a connect account request, returning /connect url containing a ticket. + + Args: + options: Options for retrieving an access token for a connection. + store_options: Optional options used to pass to the Transaction and State Store. + + Returns: + The a connect URL containing a ticket to redirect the user to. + """ + # Use the default redirect_uri if none is specified + redirect_uri = options.redirect_uri or self._redirect_uri + # Ensure we have a redirect_uri + if not redirect_uri: + raise MissingRequiredArgumentError("redirect_uri") + + # Generate PKCE code verifier and challenge + code_verifier = PKCE.generate_code_verifier() + code_challenge = PKCE.generate_code_challenge(code_verifier) + + state= PKCE.generate_random_string(32) + + connect_request = ConnectAccountRequest( + connection=options.connection, + scopes=options.scopes, + redirect_uri = redirect_uri, + code_challenge=code_challenge, + code_challenge_method="S256", + state=state, + authorization_params=options.authorization_params + ) + + access_token = await self.get_access_token( + audience=self._my_account_client.audience, + scope="create:me:connected_accounts", + store_options=store_options + ) + connect_response = await self._my_account_client.connect_account( + access_token=access_token, + request=connect_request + ) + + # Build the transaction data to store + transaction_data = TransactionData( + code_verifier=code_verifier, + app_state=options.app_state, + auth_session=connect_response.auth_session, + redirect_uri=redirect_uri + ) + + # Store the transaction data + await self._transaction_store.set( + f"{self._transaction_identifier}:{state}", + transaction_data, + options=store_options + ) + + parsedUrl = urlparse(connect_response.connect_uri) + query = urlencode({"ticket": connect_response.connect_params.ticket}) + return urlunparse((parsedUrl.scheme, parsedUrl.netloc, parsedUrl.path, parsedUrl.params, query, parsedUrl.fragment)) + + async def complete_connect_account( + self, + url: str, + store_options: dict = None + ) -> CompleteConnectAccountResponse: + """ + Handles the redirect callback to complete the connect account flow for linking a third-party + account to the user's profile. + + This works similar to the redirect from the login flow except it verifies the `connect_code` + with the My Account API rather than the `code` with the Authorization Server. + + Args: + url: The full callback URL including query parameters + store_options: Optional options used to pass to the Transaction and State Store. + + Returns: + A response from the connect account flow. + """ + # Parse the URL to get query parameters + parsed_url = urlparse(url) + query_params = parse_qs(parsed_url.query) + + # Get state parameter from the URL + state = query_params.get("state", [""])[0] + if not state: + raise MissingRequiredArgumentError("state") + + # Get the authorization code from the URL + connect_code = query_params.get("connect_code", [""])[0] + if not connect_code: + raise MissingRequiredArgumentError("connect_code") + + # Retrieve the transaction data using the state + transaction_identifier = f"{self._transaction_identifier}:{state}" + transaction_data = await self._transaction_store.get(transaction_identifier, options=store_options) + + if not transaction_data: + raise MissingTransactionError() + + access_token = await self.get_access_token( + audience=self._my_account_client.audience, + scope="create:me:connected_accounts", + store_options=store_options + ) + + request = CompleteConnectAccountRequest( + auth_session=transaction_data.auth_session, + connect_code=connect_code, + redirect_uri=transaction_data.redirect_uri, + code_verifier=transaction_data.code_verifier + ) + try: + response = await self._my_account_client.complete_connect_account( + access_token=access_token, request=request) + if transaction_data.app_state is not None: + response.app_state = transaction_data.app_state + finally: + # Clean up transaction data + await self._transaction_store.delete(transaction_identifier, options=store_options) + + return response diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index ce93101..677a7da 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -87,6 +87,8 @@ class TransactionData(BaseModel): audience: Optional[str] = None code_verifier: str app_state: Optional[Any] = None + auth_session: Optional[str] = None + redirect_uri: Optional[str] = None class Config: extra = "allow" # Allow additional fields not defined in the model @@ -210,3 +212,43 @@ class StartLinkUserOptions(BaseModel): connection_scope: Optional[str] = None authorization_params: Optional[dict[str, Any]] = None app_state: Optional[Any] = None + +class ConnectParams(BaseModel): + ticket: str + +class ConnectAccountOptions(BaseModel): + connection: str + redirect_uri: Optional[str] = None + scopes: Optional[list[str]] = None + app_state: Optional[Any] = None + authorization_params: Optional[dict[str, Any]] = None + +class ConnectAccountRequest(BaseModel): + connection: str + scopes: Optional[list[str]] = None + redirect_uri: Optional[str] = None + state: Optional[str] = None + code_challenge: Optional[str] = None + code_challenge_method: Optional[str] = 'S256' + authorization_params: Optional[dict[str, Any]] = None + +class ConnectAccountResponse(BaseModel): + auth_session: str + connect_uri: str + connect_params: ConnectParams + expires_in: int + +class CompleteConnectAccountRequest(BaseModel): + auth_session: str + connect_code: str + redirect_uri: str + code_verifier: Optional[str] = None + +class CompleteConnectAccountResponse(BaseModel): + id: str + connection: str + access_type: str + scopes: list[str] + created_at: str + expires_at: Optional[str] = None + app_state: Optional[Any] = None diff --git a/src/auth0_server_python/error/__init__.py b/src/auth0_server_python/error/__init__.py index 58ce85f..ef181ce 100644 --- a/src/auth0_server_python/error/__init__.py +++ b/src/auth0_server_python/error/__init__.py @@ -56,6 +56,26 @@ def __init__(self, code: str, message: str, interval: Optional[int], cause=None) super().__init__(code, message, cause) self.interval = interval +class MyAccountApiError(Auth0Error): + """ + Error raised when an API request to My Account API fails. + Contains details about the original error from Auth0. + """ + + def __init__( + self, + title: Optional[str], + type: Optional[str], + detail: Optional[str], + status: Optional[int], + validation_errors: Optional[list[dict[str, str]]] = None + ): + super().__init__(detail) + self.title = title + self.type = type + self.detail = detail + self.status = status + self.validation_errors = validation_errors class AccessTokenError(Auth0Error): """Error raised when there's an issue with access tokens.""" @@ -124,6 +144,7 @@ class AccessTokenErrorCode: FAILED_TO_REQUEST_TOKEN = "failed_to_request_token" REFRESH_TOKEN_ERROR = "refresh_token_error" AUTH_REQ_ID_ERROR = "auth_req_id_error" + INCORRECT_AUDIENCE = "incorrect_audience" class AccessTokenForConnectionErrorCode: diff --git a/src/auth0_server_python/tests/test_my_account_client.py b/src/auth0_server_python/tests/test_my_account_client.py new file mode 100644 index 0000000..f4f18fb --- /dev/null +++ b/src/auth0_server_python/tests/test_my_account_client.py @@ -0,0 +1,160 @@ +from unittest.mock import ANY, AsyncMock, MagicMock + +import pytest +from auth0_server_python.auth_server.my_account_client import MyAccountClient +from auth0_server_python.auth_types import ( + CompleteConnectAccountRequest, + CompleteConnectAccountResponse, + ConnectAccountRequest, + ConnectAccountResponse, + ConnectParams, +) +from auth0_server_python.error import MyAccountApiError + + +@pytest.mark.asyncio +async def test_connect_account_success(mocker): + # Arrange + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.json = MagicMock(return_value={ + "connect_uri": "https://auth0.local/connect", + "auth_session": "", + "connect_params": {"ticket": ""}, + "expires_in": 3600 + }) + + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + request = ConnectAccountRequest( + connection="", + redirect_uri="", + state="", + code_challenge="", + code_challenge_method="S256" + ) + + # Act + result = await client.connect_account(access_token="", request=request) + + # Assert + mock_post.assert_awaited_with( + url="https://auth0.local/me/v1/connected-accounts/connect", + json={ + "connection": "", + "redirect_uri": "", + "state": "", + "code_challenge": "", + "code_challenge_method": "S256", + }, + auth=ANY + ) + assert result == ConnectAccountResponse( + connect_uri="https://auth0.local/connect", + auth_session="", + connect_params=ConnectParams(ticket=""), + expires_in=3600 + ) + +@pytest.mark.asyncio +async def test_connect_account_api_response_failure(mocker): + # Arrange + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 401 + response.json = MagicMock(return_value={ + "title": "Invalid Token", + "type": "https://auth0.com/api-errors/A0E-401-0003", + "detail": "Invalid Token", + "status": 401 + }) + + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + request = ConnectAccountRequest( + connection="", + redirect_uri="", + state="", + code_challenge="", + code_challenge_method="S256" + ) + + # Act + + with pytest.raises(MyAccountApiError) as exc: + await client.connect_account(access_token="", request=request) + + # Assert + mock_post.assert_awaited_once() + assert "Invalid Token" in str(exc.value) + + +@pytest.mark.asyncio +async def test_complete_connect_account_success(mocker): + # Arrange + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 201 + response.json = MagicMock(return_value={ + "id": "", + "connection": "", + "access_type": "", + "scopes": [""], + "created_at": "", + }) + + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + request = CompleteConnectAccountRequest( + auth_session="", + connect_code="", + redirect_uri="", + ) + + # Act + result = await client.complete_connect_account(access_token="", request=request) + + # Assert + mock_post.assert_awaited_with( + url="https://auth0.local/me/v1/connected-accounts/complete", + json={ + "auth_session": "", + "connect_code": "", + "redirect_uri": "" + }, + auth=ANY + ) + assert result == CompleteConnectAccountResponse( + id="", + connection="", + access_type="", + scopes=[""], + created_at="", + ) + +@pytest.mark.asyncio +async def test_complete_connect_account_api_response_failure(mocker): + # Arrange + client = MyAccountClient(domain="auth0.local") + response = AsyncMock() + response.status_code = 401 + response.json = MagicMock(return_value={ + "title": "Invalid Token", + "type": "https://auth0.com/api-errors/A0E-401-0003", + "detail": "Invalid Token", + "status": 401 + }) + + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=response) + request = CompleteConnectAccountRequest( + auth_session="", + connect_code="", + redirect_uri="", + ) + + # Act + + with pytest.raises(MyAccountApiError) as exc: + await client.complete_connect_account(access_token="", request=request) + + # Assert + mock_post.assert_awaited_once() + assert "Invalid Token" in str(exc.value) diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index c990513..9f4f2cd 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -1,11 +1,17 @@ import json import time -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import ANY, AsyncMock, MagicMock from urllib.parse import parse_qs, urlparse import pytest +from auth0_server_python.auth_server.my_account_client import MyAccountClient from auth0_server_python.auth_server.server_client import ServerClient from auth0_server_python.auth_types import ( + CompleteConnectAccountRequest, + ConnectAccountOptions, + ConnectAccountRequest, + ConnectAccountResponse, + ConnectParams, LogoutOptions, TransactionData, ) @@ -18,6 +24,7 @@ PollingApiError, StartLinkUserError, ) +from auth0_server_python.utils import PKCE @pytest.mark.asyncio @@ -1596,3 +1603,332 @@ async def test_get_token_by_refresh_token_exchange_failed(mocker): args, kwargs = mock_post.call_args assert kwargs["data"]["refresh_token"] == "" + +@pytest.mark.asyncio +async def test_start_connect_account_calls_connect_and_builds_url(mocker): + # Setup + mock_transaction_store = AsyncMock() + mock_state_store = AsyncMock() + + client = ServerClient( + domain="auth0.local", + client_id="", + client_secret="", + state_store=mock_state_store, + transaction_store=mock_transaction_store, + secret="some-secret" + ) + + mocker.patch.object(client, "get_access_token", AsyncMock(return_value="")) + mock_my_account_client = AsyncMock(MyAccountClient) + mocker.patch.object(client, "_my_account_client", mock_my_account_client) + mock_my_account_client.connect_account.return_value = ConnectAccountResponse( + auth_session="", + connect_uri="http://auth0.local/connected_accounts/connect", + connect_params=ConnectParams( + ticket="ticket123" + ), + expires_in=300 + ) + + mocker.patch.object(PKCE, "generate_random_string", return_value="") + mocker.patch.object(PKCE, "generate_code_verifier", return_value="") + mocker.patch.object(PKCE, "generate_code_challenge", return_value="") + + # Act + url = await client.start_connect_account( + options=ConnectAccountOptions( + connection="", + app_state="", + redirect_uri="/test_redirect_uri" + ) + ) + + # Assert + assert url == "http://auth0.local/connected_accounts/connect?ticket=ticket123" + mock_my_account_client.connect_account.assert_awaited_with( + access_token="", + request=ConnectAccountRequest( + connection="", + redirect_uri="/test_redirect_uri", + code_challenge_method="S256", + code_challenge="", + state= "" + ) + ) + mock_transaction_store.set.assert_awaited_with( + "_a0_tx:", + TransactionData( + code_verifier="", + app_state="", + auth_session="", + redirect_uri="/test_redirect_uri" + ), + options=ANY + ) + +@pytest.mark.asyncio +async def test_start_connect_account_with_scopes(mocker): + # Setup + mock_transaction_store = AsyncMock() + mock_state_store = AsyncMock() + + client = ServerClient( + domain="auth0.local", + client_id="", + client_secret="", + state_store=mock_state_store, + transaction_store=mock_transaction_store, + secret="some-secret" + ) + + mocker.patch.object(client, "get_access_token", AsyncMock(return_value="")) + mock_my_account_client = AsyncMock(MyAccountClient) + mocker.patch.object(client, "_my_account_client", mock_my_account_client) + mock_my_account_client.connect_account.return_value = ConnectAccountResponse( + auth_session="", + connect_uri="http://auth0.local/connected_accounts/connect", + connect_params=ConnectParams( + ticket="ticket123" + ), + expires_in=300 + ) + + # Act + await client.start_connect_account( + options=ConnectAccountOptions( + connection="", + scopes=["scope1", "scope2", "scope3"], + redirect_uri="/test_redirect_uri" + ) + ) + + # Assert + mock_my_account_client.connect_account.assert_awaited() + request = mock_my_account_client.connect_account.mock_calls[0].kwargs["request"] + assert request.scopes == ["scope1", "scope2", "scope3"] + +@pytest.mark.asyncio +async def test_start_connect_account_default_redirect_uri(mocker): + # Setup + mock_transaction_store = AsyncMock() + mock_state_store = AsyncMock() + + client = ServerClient( + domain="auth0.local", + client_id="", + client_secret="", + state_store=mock_state_store, + transaction_store=mock_transaction_store, + secret="some-secret", + redirect_uri="/default_redirect_uri" + ) + + mocker.patch.object(client, "get_access_token", AsyncMock(return_value="")) + mock_my_account_client = AsyncMock(MyAccountClient) + mocker.patch.object(client, "_my_account_client", mock_my_account_client) + mock_my_account_client.connect_account.return_value = ConnectAccountResponse( + auth_session="", + connect_uri="http://auth0.local/connected_accounts/connect", + connect_params=ConnectParams( + ticket="ticket123", + ), + expires_in=300 + ) + + mocker.patch.object(PKCE, "generate_random_string", return_value="") + mocker.patch.object(PKCE, "generate_code_verifier", return_value="") + mocker.patch.object(PKCE, "generate_code_challenge", return_value="") + + # Act + url = await client.start_connect_account( + options=ConnectAccountOptions( + connection="", + app_state="" + ) + ) + + # Assert + assert url == "http://auth0.local/connected_accounts/connect?ticket=ticket123" + mock_my_account_client.connect_account.assert_awaited_with( + access_token="", + request=ConnectAccountRequest( + connection="", + redirect_uri="/default_redirect_uri", + code_challenge_method="S256", + code_challenge="", + state= "" + ) + ) + mock_transaction_store.set.assert_awaited_with( + "_a0_tx:", + TransactionData( + code_verifier="", + app_state="", + auth_session="", + redirect_uri="/default_redirect_uri" + ), + options=ANY + ) + +@pytest.mark.asyncio +async def test_start_connect_account_no_redirect_uri(mocker): + # Setup + mock_transaction_store = AsyncMock() + mock_state_store = AsyncMock() + + client = ServerClient( + domain="auth0.local", + client_id="", + client_secret="", + state_store=mock_state_store, + transaction_store=mock_transaction_store, + secret="some-secret" + ) + + # Act + with pytest.raises(MissingRequiredArgumentError) as exc: + await client.start_connect_account( + options=ConnectAccountOptions( + connection="" + ) + ) + + # Assert + assert "redirect_uri" in str(exc.value) + +@pytest.mark.asyncio +async def test_complete_connect_account_calls_complete(mocker): + # Setup + mock_transaction_store = AsyncMock() + mock_state_store = AsyncMock() + + client = ServerClient( + domain="auth0.local", + client_id="", + client_secret="", + state_store=mock_state_store, + transaction_store=mock_transaction_store, + secret="some-secret", + redirect_uri="/test_redirect_uri" + ) + + mocker.patch.object(client, "get_access_token", AsyncMock(return_value="")) + mock_my_account_client = AsyncMock(MyAccountClient) + mocker.patch.object(client, "_my_account_client", mock_my_account_client) + + mock_transaction_store.get.return_value = TransactionData( + code_verifier="", + app_state="", + auth_session="", + redirect_uri="/test_redirect_uri" + ) + + # Act + await client.complete_connect_account( + url="/test_redirect_uri?connect_code=&state=" + ) + + # Assert + mock_my_account_client.complete_connect_account.assert_awaited_with( + access_token="", + request=CompleteConnectAccountRequest( + auth_session="", + connect_code="", + redirect_uri="/test_redirect_uri", + code_verifier="" + ) + ) + +@pytest.mark.asyncio +async def test_complete_connect_account_no_connect_code(mocker): + # Setup + mock_transaction_store = AsyncMock() + mock_state_store = AsyncMock() + + client = ServerClient( + domain="auth0.local", + client_id="", + client_secret="", + state_store=mock_state_store, + transaction_store=mock_transaction_store, + secret="some-secret", + redirect_uri="/test_redirect_uri" + ) + + mock_my_account_client = AsyncMock(MyAccountClient) + mocker.patch.object(client, "_my_account_client", mock_my_account_client) + + mock_transaction_store.get.return_value = None # no transaction + + # Act + with pytest.raises(MissingRequiredArgumentError) as exc: + await client.complete_connect_account( + url="/test_redirect_uri?state=" + ) + + # Assert + assert "connect_code" in str(exc.value) + mock_my_account_client.complete_connect_account.assert_not_awaited() + +@pytest.mark.asyncio +async def test_complete_connect_account_no_state(mocker): + # Setup + mock_transaction_store = AsyncMock() + mock_state_store = AsyncMock() + + client = ServerClient( + domain="auth0.local", + client_id="", + client_secret="", + state_store=mock_state_store, + transaction_store=mock_transaction_store, + secret="some-secret", + redirect_uri="/test_redirect_uri" + ) + + mock_my_account_client = AsyncMock(MyAccountClient) + mocker.patch.object(client, "_my_account_client", mock_my_account_client) + + mock_transaction_store.get.return_value = None # no transaction + + # Act + with pytest.raises(MissingRequiredArgumentError) as exc: + await client.complete_connect_account( + url="/test_redirect_uri?connect_code=" + ) + + # Assert + assert "state" in str(exc.value) + mock_my_account_client.complete_connect_account.assert_not_awaited() + +@pytest.mark.asyncio +async def test_complete_connect_account_no_transactions(mocker): + # Setup + mock_transaction_store = AsyncMock() + mock_state_store = AsyncMock() + + client = ServerClient( + domain="auth0.local", + client_id="", + client_secret="", + state_store=mock_state_store, + transaction_store=mock_transaction_store, + secret="some-secret", + redirect_uri="/test_redirect_uri" + ) + + mock_my_account_client = AsyncMock(MyAccountClient) + mocker.patch.object(client, "_my_account_client", mock_my_account_client) + + mock_transaction_store.get.return_value = None # no transaction + + # Act + with pytest.raises(MissingTransactionError) as exc: + await client.complete_connect_account( + url="/test_redirect_uri?connect_code=&state=" + ) + + # Assert + assert "transaction" in str(exc.value) + mock_my_account_client.complete_connect_account.assert_not_awaited()