diff --git a/examples/RetrievingData.md b/examples/RetrievingData.md index 29290f0..88fb610 100644 --- a/examples/RetrievingData.md +++ b/examples/RetrievingData.md @@ -70,6 +70,107 @@ access_token = await server_client.get_access_token(store_options=store_options) Read more above in [Configuring the Store](./ConfigureStore.md). +## Multi-Resource Refresh Tokens (MRRT) + +Multi-Resource Refresh Tokens allow using a single refresh token to obtain access tokens for multiple audiences, simplifying token management in applications that interact with multiple backend services. + +Read more about [Multi-Resource Refresh Tokens in the Auth0 documentation](https://auth0.com/docs/secure/tokens/refresh-tokens/multi-resource-refresh-token). + + +> [!WARNING] +> When using Multi-Resource Refresh Token Configuration (MRRT), **Refresh Token Policies** on your Application need to be configured with the audiences you want to support. See the [Auth0 MRRT documentation](https://auth0.com/docs/secure/tokens/refresh-tokens/multi-resource-refresh-token) for setup instructions. +> +> **Tokens requested for audiences outside your configured policies will be ignored by Auth0, which will return a token for the default audience instead!** + +### Configuring Scopes Per Audience + +When working with multiple APIs, you can define different default scopes for each audience by passing an object instead of a string. This is particularly useful when different APIs require different default scopes: + +```python +server_client = ServerClient( + ... + authorization_params={ + "audience": "https://api.example.com", # Default audience + "scope": { + "https://api.example.com": "openid profile email offline_access read:products read:orders", + "https://analytics.example.com": "openid profile email offline_access read:analytics write:analytics", + "https://admin.example.com": "openid profile email offline_access read:admin write:admin delete:admin" + } + } +) +``` + +**How it works:** + +- Each key in the `scope` object is an `audience` identifier +- The corresponding value is the scope string for that audience +- When calling `get_access_token(audience=audience)`, the SDK automatically uses the configured scopes for that audience. When scopes are also passed in the method call, they are be merged with the default scopes for that audience. + +### Usage Example + +To retrieve access tokens for different audiences, use the `get_access_token()` method with an `audience` (and optionally also the `scope`) parameter. + +```python + +server_client = ServerClient( + ... + authorization_params={ + "audience": "https://api.example.com", # Default audience + "scope": { + "https://api.example.com": "openid email profile", + "https://analytics.example.com": "read:analytics write:analytics" + } + } +) + +# Get token for default audience +default_token = await server_client.get_access_token() +# returns token for https://api.example.com with openid, email, and profile scopes + + # Get token for different audience +data_token = await server_client.get_access_token(audience="https://analytics.example.com") +# returns token for https://analytics.example.com with read:analytics and write:analytics scopes + +# Get token with additional scopes +admin_token = await server_client.get_access_token( + audience="https://api.example.com", + scope="write:admin" +) +# returns token for https://api.example.com with openid, email, profile and write:admin scopes + +``` + +### Token Management Best Practices + +**Configure Broad Default Scopes**: Define comprehensive scopes in your `ServerClient` constructor for common use cases. This minimizes the need to request additional scopes dynamically, reducing the amount of tokens that need to be stored. + +```python +server_client = ServerClient( + ... + authorization_params={ + "audience": "https://api.example.com", # Default audience + # Configure broad default scopes for most common operations + "scope": { + "https://api.example.com": "openid profile email offline_access read:products read:orders read:users" + } + } +) +``` + +**Minimize Dynamic Scope Requests**: Avoid passing `scope` when calling `get_access_token()` unless absolutely necessary. Each `audience` + `scope` combination results in a token to store in the session, increasing session size. + +```python +# Preferred: Use default scopes +token = await server_client.get_access_token(audience="https://api.example.com") + + +# Avoid unless necessary: Dynamic scopes increase session size +token = await server_client.get_access_token( + audience="https://api.example.com" + scope="openid profile email read:products write:products admin:all" +) +``` + ## Retrieving an Access Token for a Connections The SDK's `get_access_token_for_connection()` can be used to retrieve an Access Token for a connection (e.g. `google-oauth2`) for the current logged-in user: diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 204f749..e66680f 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -40,7 +40,7 @@ # Generic type for store options TStoreOptions = TypeVar('TStoreOptions') INTERNAL_AUTHORIZE_PARAMS = ["client_id", "redirect_uri", "response_type", - "code_challenge", "code_challenge_method", "state", "nonce"] + "code_challenge", "code_challenge_method", "state", "nonce", "scope"] class ServerClient(Generic[TStoreOptions]): @@ -48,6 +48,7 @@ class ServerClient(Generic[TStoreOptions]): Main client for Auth0 server SDK. Handles authentication flows, session management, and token operations using Authlib for OIDC functionality. """ + DEFAULT_AUDIENCE_STATE_KEY = "default" def __init__( self, @@ -77,6 +78,7 @@ def __init__( transaction_identifier: Identifier for transaction data state_identifier: Identifier for state data authorization_params: Default parameters for authorization requests + pushed_authorization_requests: Whether to use PAR for authorization requests """ if not secret: raise MissingRequiredArgumentError("secret") @@ -152,10 +154,17 @@ async def start_interactive_login( state = PKCE.generate_random_string(32) auth_params["state"] = state + #merge any requested scope with defaults + requested_scope = options.authorization_params.get("scope", None) if options.authorization_params else None + audience = auth_params.get("audience", None) + merged_scope = self._merge_scope_with_defaults(requested_scope, audience) + auth_params["scope"] = merged_scope + # Build the transaction data to store transaction_data = TransactionData( code_verifier=code_verifier, - app_state=options.app_state + app_state=options.app_state, + audience=audience, ) # Store the transaction data @@ -290,7 +299,7 @@ async def complete_interactive_login( # Build a token set using the token response data token_set = TokenSet( - audience=token_response.get("audience", "default"), + audience=transaction_data.audience or self.DEFAULT_AUDIENCE_STATE_KEY, access_token=token_response.get("access_token", ""), scope=token_response.get("scope", ""), expires_at=int(time.time()) + @@ -509,7 +518,7 @@ async def login_backchannel( existing_state_data = await self._state_store.get(self._state_identifier, store_options) audience = self._default_authorization_params.get( - "audience", "default") + "audience", self.DEFAULT_AUDIENCE_STATE_KEY) state_data = State.update_state_data( audience, @@ -562,7 +571,12 @@ async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> O return session_data return None - async def get_access_token(self, store_options: Optional[dict[str, Any]] = None) -> str: + async def get_access_token( + self, + store_options: Optional[dict[str, Any]] = None, + audience: Optional[str] = None, + scope: Optional[str] = None, + ) -> str: """ Retrieves the access token from the store, or calls Auth0 when the access token is expired and a refresh token is available in the store. @@ -579,10 +593,13 @@ async def get_access_token(self, store_options: Optional[dict[str, Any]] = None) """ state_data = await self._state_store.get(self._state_identifier, store_options) - # Get audience and scope from options or use defaults auth_params = self._default_authorization_params or {} - audience = auth_params.get("audience", "default") - scope = auth_params.get("scope") + + # Get audience passed in on options or use defaults + if not audience: + audience = auth_params.get("audience", None) + + merged_scope = self._merge_scope_with_defaults(scope, audience) if state_data and hasattr(state_data, "dict") and callable(state_data.dict): state_data_dict = state_data.dict() @@ -592,10 +609,7 @@ async def get_access_token(self, store_options: Optional[dict[str, Any]] = None) # Find matching token set token_set = None if state_data_dict and "token_sets" in state_data_dict: - for ts in state_data_dict["token_sets"]: - if ts.get("audience") == audience and (not scope or ts.get("scope") == scope): - token_set = ts - break + token_set = self._find_matching_token_set(state_data_dict["token_sets"], audience, merged_scope) # If token is valid, return it if token_set and token_set.get("expires_at", 0) > time.time(): @@ -610,9 +624,14 @@ async def get_access_token(self, store_options: Optional[dict[str, Any]] = None) # Get new token with refresh token try: - token_endpoint_response = await self.get_token_by_refresh_token({ - "refresh_token": state_data_dict["refresh_token"] - }) + get_refresh_token_options = {"refresh_token": state_data_dict["refresh_token"]} + if audience: + get_refresh_token_options["audience"] = audience + + if merged_scope: + get_refresh_token_options["scope"] = merged_scope + + token_endpoint_response = await self.get_token_by_refresh_token(get_refresh_token_options) # Update state data with new token existing_state_data = await self._state_store.get(self._state_identifier, store_options) @@ -631,6 +650,51 @@ async def get_access_token(self, store_options: Optional[dict[str, Any]] = None) f"Failed to get token with refresh token: {str(e)}" ) + def _merge_scope_with_defaults( + self, + request_scope: Optional[str], + audience: Optional[str] + ) -> Optional[str]: + audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY + default_scopes = "" + if self._default_authorization_params and "scope" in self._default_authorization_params: + auth_param_scope = self._default_authorization_params.get("scope") + # For backwards compatibility, allow scope to be a single string + # or dictionary by audience for MRRT + if isinstance(auth_param_scope, dict) and audience in auth_param_scope: + default_scopes = auth_param_scope[audience] + elif isinstance(auth_param_scope, str): + default_scopes = auth_param_scope + + default_scopes_list = default_scopes.split() + request_scopes_list = (request_scope or "").split() + + merged_scopes = list(dict.fromkeys(default_scopes_list + request_scopes_list)) + return " ".join(merged_scopes) if merged_scopes else None + + + def _find_matching_token_set( + self, + token_sets: list[dict[str, Any]], + audience: Optional[str], + scope: Optional[str] + ) -> Optional[dict[str, Any]]: + audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY + requested_scopes = set(scope.split()) if scope else set() + matches: list[tuple[int, dict]] = [] + for token_set in token_sets: + token_set_audience = token_set.get("audience") + token_set_scopes = set(token_set.get("scope", "").split()) + if token_set_audience == audience and token_set_scopes == requested_scopes: + # short-circuit if exact match + return token_set + if token_set_audience == audience and token_set_scopes.issuperset(requested_scopes): + # consider stored tokens with more scopes than requested by number of scopes + matches.append((len(token_set_scopes), token_set)) + + # Return the token set with the smallest superset of scopes that matches the requested audience and scopes + return min(matches, key=lambda t: t[0])[1] if matches else None + async def get_access_token_for_connection( self, options: dict[str, Any], @@ -1143,9 +1207,18 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, "client_id": self._client_id, } - # Add scope if present in the original authorization params - if "scope" in self._default_authorization_params: - token_params["scope"] = self._default_authorization_params["scope"] + audience = options.get("audience") + if audience: + token_params["audience"] = audience + + # Merge scope if present in options with any in the original authorization params + merged_scope = self._merge_scope_with_defaults( + request_scope=options.get("scope"), + audience=audience + ) + + if merged_scope: + token_params["scope"] = merged_scope # Exchange the refresh token for an access token async with httpx.AsyncClient() as client: diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 9328a34..c990513 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -5,7 +5,10 @@ import pytest from auth0_server_python.auth_server.server_client import ServerClient -from auth0_server_python.auth_types import LogoutOptions, TransactionData +from auth0_server_python.auth_types import ( + LogoutOptions, + TransactionData, +) from auth0_server_python.error import ( AccessTokenForConnectionError, ApiError, @@ -81,7 +84,6 @@ async def test_start_interactive_login_builds_auth_url(mocker): mock_transaction_store.set.assert_awaited() mock_oauth.assert_called_once() - @pytest.mark.asyncio async def test_complete_interactive_login_no_transaction(): mock_transaction_store = AsyncMock() @@ -382,8 +384,7 @@ async def test_get_access_token_refresh_expired(mocker): secret="some-secret" ) - # Patch method that does the refresh call - mocker.patch.object(client, "get_token_by_refresh_token", return_value={ + get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token", return_value={ "access_token": "new_token", "expires_in": 3600 }) @@ -391,6 +392,350 @@ async def test_get_access_token_refresh_expired(mocker): token = await client.get_access_token() assert token == "new_token" mock_state_store.set.assert_awaited_once() + get_refresh_token_mock.assert_awaited_with({ + "refresh_token": "refresh_xyz" + }) + +@pytest.mark.asyncio +async def test_get_access_token_refresh_merging_default_scope(mocker): + mock_state_store = AsyncMock() + # expired token + mock_state_store.get.return_value = { + "refresh_token": "refresh_xyz", + "token_sets": [ + { + "audience": "default", + "access_token": "expired_token", + "expires_at": int(time.time()) - 500 + } + ] + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret", + authorization_params= { + "audience": "default", + "scope": "openid profile email" + } + ) + + get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token", return_value={ + "access_token": "new_token", + "expires_in": 3600 + }) + + token = await client.get_access_token(scope="foo:bar") + assert token == "new_token" + mock_state_store.set.assert_awaited_once() + get_refresh_token_mock.assert_awaited_with({ + "refresh_token": "refresh_xyz", + "audience": "default", + "scope": "openid profile email foo:bar" + }) + +@pytest.mark.asyncio +async def test_get_access_token_refresh_with_auth_params_scope(mocker): + mock_state_store = AsyncMock() + # expired token + mock_state_store.get.return_value = { + "refresh_token": "refresh_xyz", + "token_sets": [ + { + "audience": "default", + "access_token": "expired_token", + "expires_at": int(time.time()) - 500 + } + ] + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret", + authorization_params= { + "scope": "openid profile email" + } + ) + + get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token", return_value={ + "access_token": "new_token", + "expires_in": 3600 + }) + + token = await client.get_access_token() + assert token == "new_token" + mock_state_store.set.assert_awaited_once() + get_refresh_token_mock.assert_awaited_with({ + "refresh_token": "refresh_xyz", + "scope": "openid profile email" + }) + +@pytest.mark.asyncio +async def test_get_access_token_refresh_with_auth_params_audience(mocker): + mock_state_store = AsyncMock() + # expired token + mock_state_store.get.return_value = { + "refresh_token": "refresh_xyz", + "token_sets": [ + { + "audience": "my_audience", + "access_token": "expired_token", + "expires_at": int(time.time()) - 500 + } + ] + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret", + authorization_params= { + "audience": "my_audience" + } + ) + + get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token", return_value={ + "access_token": "new_token", + "expires_in": 3600 + }) + + token = await client.get_access_token() + assert token == "new_token" + mock_state_store.set.assert_awaited_once() + get_refresh_token_mock.assert_awaited_with({ + "refresh_token": "refresh_xyz", + "audience": "my_audience" + }) + +@pytest.mark.asyncio +async def test_get_access_token_mrrt(mocker): + mock_state_store = AsyncMock() + # expired token + mock_state_store.get.return_value = { + "refresh_token": "refresh_xyz", + "token_sets": [ + { + "audience": "default", + "access_token": "valid_token_for_other_audience", + "expires_at": int(time.time()) + 500 + } + ] + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret" + ) + + # Patch method that does the refresh call + get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token", return_value={ + "access_token": "new_token", + "expires_in": 3600 + }) + + token = await client.get_access_token( + audience="some_audience", + scope="foo:bar" + ) + + assert token == "new_token" + mock_state_store.set.assert_awaited_once() + args, kwargs = mock_state_store.set.call_args + stored_state = args[1] + assert "token_sets" in stored_state + assert len(stored_state["token_sets"]) == 2 + get_refresh_token_mock.assert_awaited_with({ + "refresh_token": "refresh_xyz", + "audience": "some_audience", + "scope": "foo:bar", + }) + +@pytest.mark.asyncio +async def test_get_access_token_mrrt_with_auth_params_scope(mocker): + mock_state_store = AsyncMock() + # expired token + mock_state_store.get.return_value = { + "refresh_token": "refresh_xyz", + "token_sets": [ + { + "audience": "default", + "access_token": "valid_token_for_other_audience", + "expires_at": int(time.time()) + 500 + } + ] + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret", + authorization_params= { + "audience": "default", + "scope": { + "default": "openid profile email foo:bar", + "some_audience": "foo:bar" + } + } + ) + + # Patch method that does the refresh call + get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token", return_value={ + "access_token": "new_token", + "expires_in": 3600 + }) + + token = await client.get_access_token( + audience="some_audience" + ) + + assert token == "new_token" + mock_state_store.set.assert_awaited_once() + args, kwargs = mock_state_store.set.call_args + stored_state = args[1] + assert "token_sets" in stored_state + assert len(stored_state["token_sets"]) == 2 + get_refresh_token_mock.assert_awaited_with({ + "refresh_token": "refresh_xyz", + "audience": "some_audience", + "scope": "foo:bar", + }) + +@pytest.mark.asyncio +async def test_get_access_token_from_store_with_multiple_audiences(mocker): + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "refresh_token": None, + "token_sets": [ + { + "audience": "default", + "access_token": "token_from_store", + "expires_at": int(time.time()) + 500 + }, + { + "audience": "some_audience", + "access_token": "other_token_from_store", + "scope": "foo:bar", + "expires_at": int(time.time()) + 500 + } + ] + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret" + ) + + get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token") + + token = await client.get_access_token( + audience="some_audience", + scope="foo:bar" + ) + + assert token == "other_token_from_store" + get_refresh_token_mock.assert_not_awaited() + +@pytest.mark.asyncio +async def test_get_access_token_from_store_with_a_superset_of_requested_scopes(mocker): + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "refresh_token": None, + "token_sets": [ + { + "audience": "default", + "access_token": "token_from_store", + "expires_at": int(time.time()) + 500 + }, + { + "audience": "some_audience", + "access_token": "other_token_from_store", + "scope": "read:foo write:foo read:bar write:bar", + "expires_at": int(time.time()) + 500 + } + ] + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret" + ) + + get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token") + + token = await client.get_access_token( + audience="some_audience", + scope="read:foo read:bar" + ) + + assert token == "other_token_from_store" + get_refresh_token_mock.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_access_token_from_store_returns_minimum_matching_scopes(mocker): + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "refresh_token": None, + "token_sets": [ + { + "audience": "some_audience", + "access_token": "maximum_scope_token", + "scope": "read:foo write:foo read:bar write:bar admin:all", + "expires_at": int(time.time()) + 500 + }, + { + "audience": "some_audience", + "access_token": "minimum_scope_token", + "scope": "read:foo write:foo read:bar write:bar", + "expires_at": int(time.time()) + 500 + } + ] + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret" + ) + + get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token") + + token = await client.get_access_token( + audience="some_audience", + scope="read:foo read:bar" + ) + + assert token == "minimum_scope_token" + get_refresh_token_mock.assert_not_awaited() @pytest.mark.asyncio async def test_get_access_token_for_connection_cached(): @@ -1251,4 +1596,3 @@ async def test_get_token_by_refresh_token_exchange_failed(mocker): args, kwargs = mock_post.call_args assert kwargs["data"]["refresh_token"] == "" -