Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/bedrock_agentcore/identity/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import os
from functools import wraps
from typing import Any, Callable, List, Literal, Optional
from typing import Any, Callable, Dict, List, Literal, Optional

import boto3

Expand All @@ -29,6 +29,7 @@ def requires_access_token(
force_authentication: bool = False,
token_poller: Optional[TokenPoller] = None,
custom_state: Optional[str] = None,
custom_parameters: Optional[Dict[str, str]] = None,
) -> Callable:
"""Decorator that fetches an OAuth2 access token before calling the decorated function.

Expand All @@ -42,6 +43,8 @@ def requires_access_token(
force_authentication: Force re-authentication
token_poller: Custom token poller implementation
custom_state: A state that allows applications to verify the validity of callbacks to callback_url
custom_parameters: A map of custom parameters to include in authorization request to the credential provider
Note: these parameters are in addition to standard OAuth 2.0 flow parameters

Returns:
Decorator function
Expand All @@ -62,6 +65,7 @@ async def _get_token() -> str:
force_authentication=force_authentication,
token_poller=token_poller,
custom_state=custom_state,
custom_parameters=custom_parameters,
)

@wraps(func)
Expand Down
5 changes: 5 additions & 0 deletions src/bedrock_agentcore/services/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ async def get_token(
force_authentication: bool = False,
token_poller: Optional[TokenPoller] = None,
custom_state: Optional[str] = None,
custom_parameters: Optional[Dict[str, str]] = None,
) -> str:
"""Get an OAuth2 access token for the specified provider.

Expand All @@ -181,6 +182,8 @@ async def get_token(
force_authentication: Force re-authentication even if token exists in the token vault
token_poller: Custom token poller implementation
custom_state: A state that allows applications to verify the validity of callbacks to callback_url
custom_parameters: A map of custom parameters to include in authorization request to the credential provider
Note: these parameters are in addition to standard OAuth 2.0 flow parameters

Returns:
The access token string
Expand All @@ -206,6 +209,8 @@ async def get_token(
req["forceAuthentication"] = force_authentication
if custom_state:
req["customState"] = custom_state
if custom_parameters:
req["customParameters"] = custom_parameters

response = self.dp_client.get_resource_oauth2_token(**req)

Expand Down
44 changes: 44 additions & 0 deletions tests/bedrock_agentcore/identity/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ async def test_async_func(param1, access_token=None):
force_authentication=False,
token_poller=None,
custom_state=None,
custom_parameters=None,
)

def test_sync_function_decoration_no_running_loop(self):
Expand Down Expand Up @@ -167,6 +168,7 @@ async def test_func(param1, my_token=None):
force_authentication=False,
token_poller=None,
custom_state=None,
custom_parameters=None,
)

@pytest.mark.asyncio
Expand Down Expand Up @@ -208,6 +210,7 @@ def on_auth_url(url):
force_authentication=True,
token_poller=mock_poller,
custom_state="myAppState",
custom_parameters=None,
)
async def test_func(token=None):
return f"token={token}"
Expand All @@ -225,6 +228,47 @@ async def test_func(token=None):
force_authentication=True,
token_poller=mock_poller,
custom_state="myAppState",
custom_parameters=None,
)

@pytest.mark.asyncio
async def test_custom_parameters_passed_to_client(self):
with patch("bedrock_agentcore.identity.auth.IdentityClient") as mock_identity_client_class:
mock_client = Mock()
mock_identity_client_class.return_value = mock_client

with patch(
"bedrock_agentcore.identity.auth._get_workload_access_token", new_callable=AsyncMock
) as mock_get_agent_token:
mock_get_agent_token.return_value = "test-agent-token"
mock_client.get_token = AsyncMock(return_value="test-access-token")

with patch("bedrock_agentcore.identity.auth._get_region", return_value="us-west-2"):
custom_params = {"param1": "value1", "param2": "value2"}

@requires_access_token(
provider_name="test-provider",
scopes=["read"],
auth_flow="USER_FEDERATION",
custom_parameters=custom_params,
)
async def test_func(access_token=None):
return access_token

result = await test_func()

assert result == "test-access-token"
mock_client.get_token.assert_called_once_with(
provider_name="test-provider",
agent_identity_token="test-agent-token",
scopes=["read"],
auth_flow="USER_FEDERATION",
callback_url=None,
force_authentication=False,
token_poller=None,
custom_state=None,
on_auth_url=None,
custom_parameters=custom_params,
)


Expand Down
35 changes: 35 additions & 0 deletions tests/bedrock_agentcore/services/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,41 @@ async def test_get_token_with_custom_token_poller(self):
assert result == expected_token
custom_poller.poll_for_token.assert_called_once()

@pytest.mark.asyncio
async def test_get_token_with_custom_parameters(self):
region = "us-west-2"

with patch("boto3.client") as mock_boto_client:
mock_client = Mock()
mock_boto_client.return_value = mock_client

identity_client = IdentityClient(region)

provider_name = "test-provider"
scopes = ["read", "write"]
agent_identity_token = "test-agent-token"
custom_parameters = {"param1": "value1", "param2": "value2"}
expected_token = "test-access-token"

mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token}

result = await identity_client.get_token(
provider_name=provider_name,
scopes=scopes,
agent_identity_token=agent_identity_token,
auth_flow="USER_FEDERATION",
custom_parameters=custom_parameters,
)

assert result == expected_token
mock_client.get_resource_oauth2_token.assert_called_once_with(
resourceCredentialProviderName=provider_name,
scopes=scopes,
oauth2Flow="USER_FEDERATION",
workloadIdentityToken=agent_identity_token,
customParameters=custom_parameters,
)

@pytest.mark.asyncio
async def test_get_api_key_success(self):
"""Test successful API key retrieval."""
Expand Down
Loading