Skip to content

Commit 18b0f77

Browse files
committed
Add support for Identity OAuth 2.0 federation enhancements
1 parent 0027298 commit 18b0f77

File tree

3 files changed

+12
-0
lines changed

3 files changed

+12
-0
lines changed

src/bedrock_agentcore/identity/auth.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def requires_access_token(
2828
callback_url: Optional[str] = None,
2929
force_authentication: bool = False,
3030
token_poller: Optional[TokenPoller] = None,
31+
session_state: Optional[str] = None,
3132
) -> Callable:
3233
"""Decorator that fetches an OAuth2 access token before calling the decorated function.
3334
@@ -59,6 +60,7 @@ async def _get_token() -> str:
5960
callback_url=callback_url,
6061
force_authentication=force_authentication,
6162
token_poller=token_poller,
63+
session_state=session_state,
6264
)
6365

6466
@wraps(func)

src/bedrock_agentcore/services/identity.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ async def get_token(
120120
callback_url: Optional[str] = None,
121121
force_authentication: bool = False,
122122
token_poller: Optional[TokenPoller] = None,
123+
session_state: Optional[str] = None,
123124
) -> str:
124125
"""Get an OAuth2 access token for the specified provider.
125126
@@ -132,6 +133,7 @@ async def get_token(
132133
callback_url: OAuth2 callback URL (must be pre-registered)
133134
force_authentication: Force re-authentication even if token exists in the token vault
134135
token_poller: Custom token poller implementation
136+
session_state: A state that allows applications to verify the validity of call backs to callback_url
135137
136138
Returns:
137139
The access token string
@@ -155,6 +157,8 @@ async def get_token(
155157
req["resourceOauth2ReturnUrl"] = callback_url
156158
if force_authentication:
157159
req["forceAuthentication"] = force_authentication
160+
if session_state:
161+
req["includedState"] = session_state
158162

159163
response = self.dp_client.get_resource_oauth2_token(**req)
160164

@@ -176,6 +180,9 @@ async def get_token(
176180
if force_authentication:
177181
req["forceAuthentication"] = False
178182

183+
if "sessionUri" in response:
184+
req["sessionUri"] = response["sessionUri"]
185+
179186
# Poll for the token
180187
active_poller = token_poller or _DefaultApiTokenPoller(
181188
auth_url, lambda: self.dp_client.get_resource_oauth2_token(**req).get("accessToken", None)

tests/bedrock_agentcore/services/test_identity.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ async def test_get_token_with_optional_parameters(self):
239239
agent_identity_token = "test-agent-token"
240240
callback_url = "https://example.com/callback"
241241
force_authentication = True
242+
session_state = "myAppIncludedState"
242243
expected_token = "test-access-token"
243244

244245
mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token}
@@ -250,6 +251,7 @@ async def test_get_token_with_optional_parameters(self):
250251
auth_flow="USER_FEDERATION",
251252
callback_url=callback_url,
252253
force_authentication=force_authentication,
254+
session_state=session_state,
253255
)
254256

255257
assert result == expected_token
@@ -260,6 +262,7 @@ async def test_get_token_with_optional_parameters(self):
260262
workloadIdentityToken=agent_identity_token,
261263
resourceOauth2ReturnUrl=callback_url,
262264
forceAuthentication=force_authentication,
265+
includedState=session_state,
263266
)
264267

265268
@pytest.mark.asyncio

0 commit comments

Comments
 (0)