diff --git a/src/bedrock_agentcore/identity/auth.py b/src/bedrock_agentcore/identity/auth.py index b763775..c8b08ba 100644 --- a/src/bedrock_agentcore/identity/auth.py +++ b/src/bedrock_agentcore/identity/auth.py @@ -5,7 +5,7 @@ import logging import os from functools import wraps -from typing import Any, Callable, List, Literal, Optional +from typing import Any, Callable, Dict, List, Literal, Optional import boto3 @@ -29,6 +29,7 @@ def requires_access_token( force_authentication: bool = False, token_poller: Optional[TokenPoller] = None, custom_state: Optional[str] = None, + custom_parameters: Optional[Dict[str, str]] = None, ) -> Callable: """Decorator that fetches an OAuth2 access token before calling the decorated function. @@ -42,6 +43,8 @@ def requires_access_token( force_authentication: Force re-authentication token_poller: Custom token poller implementation custom_state: A state that allows applications to verify the validity of callbacks to callback_url + custom_parameters: A map of custom parameters to include in authorization request to the credential provider + Note: these parameters are in addition to standard OAuth 2.0 flow parameters Returns: Decorator function @@ -62,6 +65,7 @@ async def _get_token() -> str: force_authentication=force_authentication, token_poller=token_poller, custom_state=custom_state, + custom_parameters=custom_parameters, ) @wraps(func) diff --git a/src/bedrock_agentcore/services/identity.py b/src/bedrock_agentcore/services/identity.py index f8d0e86..80d30ac 100644 --- a/src/bedrock_agentcore/services/identity.py +++ b/src/bedrock_agentcore/services/identity.py @@ -168,6 +168,7 @@ async def get_token( force_authentication: bool = False, token_poller: Optional[TokenPoller] = None, custom_state: Optional[str] = None, + custom_parameters: Optional[Dict[str, str]] = None, ) -> str: """Get an OAuth2 access token for the specified provider. @@ -181,6 +182,8 @@ async def get_token( force_authentication: Force re-authentication even if token exists in the token vault token_poller: Custom token poller implementation custom_state: A state that allows applications to verify the validity of callbacks to callback_url + custom_parameters: A map of custom parameters to include in authorization request to the credential provider + Note: these parameters are in addition to standard OAuth 2.0 flow parameters Returns: The access token string @@ -206,6 +209,8 @@ async def get_token( req["forceAuthentication"] = force_authentication if custom_state: req["customState"] = custom_state + if custom_parameters: + req["customParameters"] = custom_parameters response = self.dp_client.get_resource_oauth2_token(**req) diff --git a/tests/bedrock_agentcore/identity/test_auth.py b/tests/bedrock_agentcore/identity/test_auth.py index 2629bb1..0e520b2 100644 --- a/tests/bedrock_agentcore/identity/test_auth.py +++ b/tests/bedrock_agentcore/identity/test_auth.py @@ -55,6 +55,7 @@ async def test_async_func(param1, access_token=None): force_authentication=False, token_poller=None, custom_state=None, + custom_parameters=None, ) def test_sync_function_decoration_no_running_loop(self): @@ -167,6 +168,7 @@ async def test_func(param1, my_token=None): force_authentication=False, token_poller=None, custom_state=None, + custom_parameters=None, ) @pytest.mark.asyncio @@ -208,6 +210,7 @@ def on_auth_url(url): force_authentication=True, token_poller=mock_poller, custom_state="myAppState", + custom_parameters=None, ) async def test_func(token=None): return f"token={token}" @@ -225,6 +228,47 @@ async def test_func(token=None): force_authentication=True, token_poller=mock_poller, custom_state="myAppState", + custom_parameters=None, + ) + + @pytest.mark.asyncio + async def test_custom_parameters_passed_to_client(self): + with patch("bedrock_agentcore.identity.auth.IdentityClient") as mock_identity_client_class: + mock_client = Mock() + mock_identity_client_class.return_value = mock_client + + with patch( + "bedrock_agentcore.identity.auth._get_workload_access_token", new_callable=AsyncMock + ) as mock_get_agent_token: + mock_get_agent_token.return_value = "test-agent-token" + mock_client.get_token = AsyncMock(return_value="test-access-token") + + with patch("bedrock_agentcore.identity.auth._get_region", return_value="us-west-2"): + custom_params = {"param1": "value1", "param2": "value2"} + + @requires_access_token( + provider_name="test-provider", + scopes=["read"], + auth_flow="USER_FEDERATION", + custom_parameters=custom_params, + ) + async def test_func(access_token=None): + return access_token + + result = await test_func() + + assert result == "test-access-token" + mock_client.get_token.assert_called_once_with( + provider_name="test-provider", + agent_identity_token="test-agent-token", + scopes=["read"], + auth_flow="USER_FEDERATION", + callback_url=None, + force_authentication=False, + token_poller=None, + custom_state=None, + on_auth_url=None, + custom_parameters=custom_params, ) diff --git a/tests/bedrock_agentcore/services/test_identity.py b/tests/bedrock_agentcore/services/test_identity.py index 477d3e1..d3801c1 100644 --- a/tests/bedrock_agentcore/services/test_identity.py +++ b/tests/bedrock_agentcore/services/test_identity.py @@ -308,6 +308,41 @@ async def test_get_token_with_custom_token_poller(self): assert result == expected_token custom_poller.poll_for_token.assert_called_once() + @pytest.mark.asyncio + async def test_get_token_with_custom_parameters(self): + region = "us-west-2" + + with patch("boto3.client") as mock_boto_client: + mock_client = Mock() + mock_boto_client.return_value = mock_client + + identity_client = IdentityClient(region) + + provider_name = "test-provider" + scopes = ["read", "write"] + agent_identity_token = "test-agent-token" + custom_parameters = {"param1": "value1", "param2": "value2"} + expected_token = "test-access-token" + + mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token} + + result = await identity_client.get_token( + provider_name=provider_name, + scopes=scopes, + agent_identity_token=agent_identity_token, + auth_flow="USER_FEDERATION", + custom_parameters=custom_parameters, + ) + + assert result == expected_token + mock_client.get_resource_oauth2_token.assert_called_once_with( + resourceCredentialProviderName=provider_name, + scopes=scopes, + oauth2Flow="USER_FEDERATION", + workloadIdentityToken=agent_identity_token, + customParameters=custom_parameters, + ) + @pytest.mark.asyncio async def test_get_api_key_success(self): """Test successful API key retrieval."""