From 2cbd3a7cdceb22a4e69e0e2312e6957309d408e6 Mon Sep 17 00:00:00 2001 From: Brian Lao Date: Mon, 6 Oct 2025 16:57:01 -0400 Subject: [PATCH] Add support for Identity OAuth 2.0 federation enhancements --- src/bedrock_agentcore/identity/auth.py | 3 +++ src/bedrock_agentcore/services/identity.py | 7 +++++++ tests/bedrock_agentcore/identity/test_auth.py | 4 ++++ tests/bedrock_agentcore/services/test_identity.py | 13 ++++++++++++- 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/bedrock_agentcore/identity/auth.py b/src/bedrock_agentcore/identity/auth.py index ca7d9c0..64b59b5 100644 --- a/src/bedrock_agentcore/identity/auth.py +++ b/src/bedrock_agentcore/identity/auth.py @@ -28,6 +28,7 @@ def requires_access_token( callback_url: Optional[str] = None, force_authentication: bool = False, token_poller: Optional[TokenPoller] = None, + custom_state: Optional[str] = None, ) -> Callable: """Decorator that fetches an OAuth2 access token before calling the decorated function. @@ -40,6 +41,7 @@ def requires_access_token( callback_url: OAuth2 callback URL force_authentication: Force re-authentication token_poller: Custom token poller implementation + custom_state: A state that allows applications to verify the validity of callbacks to callback_url Returns: Decorator function @@ -59,6 +61,7 @@ async def _get_token() -> str: callback_url=callback_url, force_authentication=force_authentication, token_poller=token_poller, + custom_state=custom_state, ) @wraps(func) diff --git a/src/bedrock_agentcore/services/identity.py b/src/bedrock_agentcore/services/identity.py index ffafc3f..402a1b6 100644 --- a/src/bedrock_agentcore/services/identity.py +++ b/src/bedrock_agentcore/services/identity.py @@ -120,6 +120,7 @@ async def get_token( callback_url: Optional[str] = None, force_authentication: bool = False, token_poller: Optional[TokenPoller] = None, + custom_state: Optional[str] = None, ) -> str: """Get an OAuth2 access token for the specified provider. @@ -132,6 +133,7 @@ async def get_token( callback_url: OAuth2 callback URL (must be pre-registered) force_authentication: Force re-authentication even if token exists in the token vault token_poller: Custom token poller implementation + custom_state: A state that allows applications to verify the validity of callbacks to callback_url Returns: The access token string @@ -155,6 +157,8 @@ async def get_token( req["resourceOauth2ReturnUrl"] = callback_url if force_authentication: req["forceAuthentication"] = force_authentication + if custom_state: + req["customState"] = custom_state response = self.dp_client.get_resource_oauth2_token(**req) @@ -176,6 +180,9 @@ async def get_token( if force_authentication: req["forceAuthentication"] = False + if "sessionUri" in response: + req["sessionUri"] = response["sessionUri"] + # Poll for the token active_poller = token_poller or _DefaultApiTokenPoller( auth_url, lambda: self.dp_client.get_resource_oauth2_token(**req).get("accessToken", None) diff --git a/tests/bedrock_agentcore/identity/test_auth.py b/tests/bedrock_agentcore/identity/test_auth.py index 34c0777..bd06466 100644 --- a/tests/bedrock_agentcore/identity/test_auth.py +++ b/tests/bedrock_agentcore/identity/test_auth.py @@ -54,6 +54,7 @@ async def test_async_func(param1, access_token=None): callback_url=None, force_authentication=False, token_poller=None, + custom_state=None, ) def test_sync_function_decoration_no_running_loop(self): @@ -165,6 +166,7 @@ async def test_func(param1, my_token=None): callback_url=None, force_authentication=False, token_poller=None, + custom_state=None, ) @pytest.mark.asyncio @@ -205,6 +207,7 @@ def on_auth_url(url): callback_url="https://example.com/callback", force_authentication=True, token_poller=mock_poller, + custom_state="myAppState", ) async def test_func(token=None): return f"token={token}" @@ -221,6 +224,7 @@ async def test_func(token=None): callback_url="https://example.com/callback", force_authentication=True, token_poller=mock_poller, + custom_state="myAppState", ) diff --git a/tests/bedrock_agentcore/services/test_identity.py b/tests/bedrock_agentcore/services/test_identity.py index 4c0b723..ffba93d 100644 --- a/tests/bedrock_agentcore/services/test_identity.py +++ b/tests/bedrock_agentcore/services/test_identity.py @@ -117,11 +117,13 @@ async def test_get_token_with_auth_url_polling(self): agent_identity_token = "test-agent-token" auth_url = "https://example.com/auth" expected_token = "test-access-token" + session_uri = "https://example-federation-authorization-request/12345" # First call returns auth URL, subsequent calls return token mock_client.get_resource_oauth2_token.side_effect = [ {"authorizationUrl": auth_url}, {"accessToken": expected_token}, + {"sessionUri": session_uri}, ] # Mock the token poller @@ -239,6 +241,7 @@ async def test_get_token_with_optional_parameters(self): agent_identity_token = "test-agent-token" callback_url = "https://example.com/callback" force_authentication = True + custom_state = "myAppCustomState" expected_token = "test-access-token" mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token} @@ -250,6 +253,7 @@ async def test_get_token_with_optional_parameters(self): auth_flow="USER_FEDERATION", callback_url=callback_url, force_authentication=force_authentication, + custom_state=custom_state, ) assert result == expected_token @@ -260,6 +264,7 @@ async def test_get_token_with_optional_parameters(self): workloadIdentityToken=agent_identity_token, resourceOauth2ReturnUrl=callback_url, forceAuthentication=force_authentication, + customState=custom_state, ) @pytest.mark.asyncio @@ -278,8 +283,13 @@ async def test_get_token_with_custom_token_poller(self): agent_identity_token = "test-agent-token" auth_url = "https://example.com/auth" expected_token = "test-access-token" + force_authentication = True + session_uri = "https://example-federation-authorization-request/12345" - mock_client.get_resource_oauth2_token.return_value = {"authorizationUrl": auth_url} + mock_client.get_resource_oauth2_token.return_value = { + "authorizationUrl": auth_url, + "sessionUri": session_uri, + } # Mock custom token poller custom_poller = Mock() @@ -290,6 +300,7 @@ async def test_get_token_with_custom_token_poller(self): agent_identity_token=agent_identity_token, auth_flow="USER_FEDERATION", token_poller=custom_poller, + force_authentication=force_authentication, ) assert result == expected_token