diff --git a/src/bedrock_agentcore/identity/auth.py b/src/bedrock_agentcore/identity/auth.py index ca7d9c0..1cad8a1 100644 --- a/src/bedrock_agentcore/identity/auth.py +++ b/src/bedrock_agentcore/identity/auth.py @@ -26,6 +26,7 @@ def requires_access_token( on_auth_url: Optional[Callable[[str], Any]] = None, auth_flow: Literal["M2M", "USER_FEDERATION"], callback_url: Optional[str] = None, + custom_parameters: Optional[dict] = {}, force_authentication: bool = False, token_poller: Optional[TokenPoller] = None, ) -> Callable: @@ -38,6 +39,7 @@ def requires_access_token( on_auth_url: Callback for handling authorization URLs auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION") callback_url: OAuth2 callback URL + custom_parameters: optional parameters to be sent to the authorizer endpoint of the provider force_authentication: Force re-authentication token_poller: Custom token poller implementation @@ -57,6 +59,7 @@ async def _get_token() -> str: on_auth_url=on_auth_url, auth_flow=auth_flow, callback_url=callback_url, + custom_parameters=custom_parameters, force_authentication=force_authentication, token_poller=token_poller, ) @@ -173,7 +176,9 @@ async def _set_up_local_auth(client: IdentityClient) -> str: workload_identity_name = config.get("workload_identity_name") if workload_identity_name: - print(f"Found existing workload identity from {config_path.absolute()}: {workload_identity_name}") + print( + f"Found existing workload identity from {config_path.absolute()}: {workload_identity_name}" + ) else: workload_identity_name = client.create_workload_identity()["name"] print("Created a workload identity") @@ -192,7 +197,9 @@ async def _set_up_local_auth(client: IdentityClient) -> str: except Exception: print("Warning: could not write the created workload identity to file") - return client.get_workload_access_token(workload_identity_name, user_id=user_id)["workloadAccessToken"] + return client.get_workload_access_token(workload_identity_name, user_id=user_id)[ + "workloadAccessToken" + ] def _get_region() -> str: diff --git a/src/bedrock_agentcore/services/identity.py b/src/bedrock_agentcore/services/identity.py index ffafc3f..14bae17 100644 --- a/src/bedrock_agentcore/services/identity.py +++ b/src/bedrock_agentcore/services/identity.py @@ -9,7 +9,10 @@ import boto3 -from bedrock_agentcore._utils.endpoints import get_control_plane_endpoint, get_data_plane_endpoint +from bedrock_agentcore._utils.endpoints import ( + get_control_plane_endpoint, + get_data_plane_endpoint, +) class TokenPoller(ABC): @@ -44,7 +47,9 @@ async def poll_for_token(self) -> str: while time.time() - start_time < DEFAULT_POLLING_TIMEOUT_SECONDS: await asyncio.sleep(DEFAULT_POLLING_INTERVAL_SECONDS) - self.logger.info("Polling for token for authorization url: %s", self.auth_url) + self.logger.info( + "Polling for token for authorization url: %s", self.auth_url + ) resp = self.polling_func() if resp is not None: self.logger.info("Token is ready") @@ -63,13 +68,19 @@ def __init__(self, region: str): """Initialize the identity client with the specified region.""" self.region = region self.cp_client = boto3.client( - "bedrock-agentcore-control", region_name=region, endpoint_url=get_control_plane_endpoint(region) + "bedrock-agentcore-control", + region_name=region, + endpoint_url=get_control_plane_endpoint(region), ) self.identity_client = boto3.client( - "bedrock-agentcore-control", region_name=region, endpoint_url=get_data_plane_endpoint(region) + "bedrock-agentcore-control", + region_name=region, + endpoint_url=get_data_plane_endpoint(region), ) self.dp_client = boto3.client( - "bedrock-agentcore", region_name=region, endpoint_url=get_data_plane_endpoint(region) + "bedrock-agentcore", + region_name=region, + endpoint_url=get_data_plane_endpoint(region), ) self.logger = logging.getLogger("bedrock_agentcore.identity_client") @@ -84,17 +95,26 @@ def create_api_key_credential_provider(self, req): return self.cp_client.create_api_key_credential_provider(**req) def get_workload_access_token( - self, workload_name: str, user_token: Optional[str] = None, user_id: Optional[str] = None + self, + workload_name: str, + user_token: Optional[str] = None, + user_id: Optional[str] = None, ) -> Dict: """Get a workload access token using workload name and optionally user token.""" if user_token: if user_id is not None: - self.logger.warning("Both user token and user id are supplied, using user token") + self.logger.warning( + "Both user token and user id are supplied, using user token" + ) self.logger.info("Getting workload access token for JWT...") - resp = self.dp_client.get_workload_access_token_for_jwt(workloadName=workload_name, userToken=user_token) + resp = self.dp_client.get_workload_access_token_for_jwt( + workloadName=workload_name, userToken=user_token + ) elif user_id: self.logger.info("Getting workload access token for user id...") - resp = self.dp_client.get_workload_access_token_for_user_id(workloadName=workload_name, userId=user_id) + resp = self.dp_client.get_workload_access_token_for_user_id( + workloadName=workload_name, userId=user_id + ) else: self.logger.info("Getting workload access token...") resp = self.dp_client.get_workload_access_token(workloadName=workload_name) @@ -118,6 +138,7 @@ async def get_token( on_auth_url: Optional[Callable[[str], Any]] = None, auth_flow: Literal["M2M", "USER_FEDERATION"], callback_url: Optional[str] = None, + custom_parameters: Optional[dict] = {}, force_authentication: bool = False, token_poller: Optional[TokenPoller] = None, ) -> str: @@ -130,6 +151,7 @@ async def get_token( on_auth_url: Callback for handling authorization URLs auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION") callback_url: OAuth2 callback URL (must be pre-registered) + custom_parameters: optional parameters to be sent to the authorizer endpoint of the provider force_authentication: Force re-authentication even if token exists in the token vault token_poller: Custom token poller implementation @@ -148,6 +170,7 @@ async def get_token( "scopes": scopes, "oauth2Flow": auth_flow, "workloadIdentityToken": agent_identity_token, + "customParameters": custom_parameters, } # Add optional parameters @@ -178,15 +201,25 @@ async def get_token( # Poll for the token active_poller = token_poller or _DefaultApiTokenPoller( - auth_url, lambda: self.dp_client.get_resource_oauth2_token(**req).get("accessToken", None) + auth_url, + lambda: self.dp_client.get_resource_oauth2_token(**req).get( + "accessToken", None + ), ) return await active_poller.poll_for_token() - raise RuntimeError("Identity service did not return a token or an authorization URL.") + raise RuntimeError( + "Identity service did not return a token or an authorization URL." + ) - async def get_api_key(self, *, provider_name: str, agent_identity_token: str) -> str: + async def get_api_key( + self, *, provider_name: str, agent_identity_token: str + ) -> str: """Programmatically retrieves an API key from the Identity service.""" self.logger.info("Getting API key...") - req = {"resourceCredentialProviderName": provider_name, "workloadIdentityToken": agent_identity_token} + req = { + "resourceCredentialProviderName": provider_name, + "workloadIdentityToken": agent_identity_token, + } return self.dp_client.get_resource_api_key(**req)["apiKey"] diff --git a/tests/bedrock_agentcore/identity/test_auth.py b/tests/bedrock_agentcore/identity/test_auth.py index 34c0777..4ad7c09 100644 --- a/tests/bedrock_agentcore/identity/test_auth.py +++ b/tests/bedrock_agentcore/identity/test_auth.py @@ -22,13 +22,16 @@ class TestRequiresAccessTokenDecorator: async def test_async_function_decoration(self): """Test decorator with async function.""" # Mock IdentityClient - with patch("bedrock_agentcore.identity.auth.IdentityClient") as mock_identity_client_class: + with patch( + "bedrock_agentcore.identity.auth.IdentityClient" + ) as mock_identity_client_class: mock_client = Mock() mock_identity_client_class.return_value = mock_client # Mock _get_workload_access_token with patch( - "bedrock_agentcore.identity.auth._get_workload_access_token", new_callable=AsyncMock + "bedrock_agentcore.identity.auth._get_workload_access_token", + new_callable=AsyncMock, ) as mock_get_agent_token: mock_get_agent_token.return_value = "test-agent-token" @@ -36,9 +39,16 @@ async def test_async_function_decoration(self): mock_client.get_token = AsyncMock(return_value="test-access-token") # Mock _get_region - with patch("bedrock_agentcore.identity.auth._get_region", return_value="us-west-2"): + with patch( + "bedrock_agentcore.identity.auth._get_region", + return_value="us-west-2", + ): - @requires_access_token(provider_name="test-provider", scopes=["read", "write"], auth_flow="M2M") + @requires_access_token( + provider_name="test-provider", + scopes=["read", "write"], + auth_flow="M2M", + ) async def test_async_func(param1, access_token=None): return f"param1={param1}, token={access_token}" @@ -52,6 +62,7 @@ async def test_async_func(param1, access_token=None): on_auth_url=None, auth_flow="M2M", callback_url=None, + custom_parameters={}, force_authentication=False, token_poller=None, ) @@ -59,13 +70,16 @@ async def test_async_func(param1, access_token=None): def test_sync_function_decoration_no_running_loop(self): """Test decorator with sync function when no asyncio loop is running.""" # Mock IdentityClient - with patch("bedrock_agentcore.identity.auth.IdentityClient") as mock_identity_client_class: + with patch( + "bedrock_agentcore.identity.auth.IdentityClient" + ) as mock_identity_client_class: mock_client = Mock() mock_identity_client_class.return_value = mock_client # Mock _get_workload_access_token with patch( - "bedrock_agentcore.identity.auth._get_workload_access_token", new_callable=AsyncMock + "bedrock_agentcore.identity.auth._get_workload_access_token", + new_callable=AsyncMock, ) as mock_get_agent_token: mock_get_agent_token.return_value = "test-agent-token" @@ -73,14 +87,24 @@ def test_sync_function_decoration_no_running_loop(self): mock_client.get_token = AsyncMock(return_value="test-access-token") # Mock _get_region - with patch("bedrock_agentcore.identity.auth._get_region", return_value="us-west-2"): + with patch( + "bedrock_agentcore.identity.auth._get_region", + return_value="us-west-2", + ): - @requires_access_token(provider_name="test-provider", scopes=["read"], auth_flow="USER_FEDERATION") + @requires_access_token( + provider_name="test-provider", + scopes=["read"], + auth_flow="USER_FEDERATION", + ) def test_sync_func(param1, access_token=None): return f"param1={param1}, token={access_token}" # Mock asyncio.get_running_loop to raise RuntimeError (no loop) - with patch("asyncio.get_running_loop", side_effect=RuntimeError("no running loop")): + with patch( + "asyncio.get_running_loop", + side_effect=RuntimeError("no running loop"), + ): with patch("asyncio.run") as mock_asyncio_run: mock_asyncio_run.return_value = "test-access-token" @@ -92,13 +116,16 @@ def test_sync_func(param1, access_token=None): def test_sync_function_decoration_with_running_loop(self): """Test decorator with sync function when asyncio loop is running.""" # Mock IdentityClient - with patch("bedrock_agentcore.identity.auth.IdentityClient") as mock_identity_client_class: + with patch( + "bedrock_agentcore.identity.auth.IdentityClient" + ) as mock_identity_client_class: mock_client = Mock() mock_identity_client_class.return_value = mock_client # Mock _get_workload_access_token with patch( - "bedrock_agentcore.identity.auth._get_workload_access_token", new_callable=AsyncMock + "bedrock_agentcore.identity.auth._get_workload_access_token", + new_callable=AsyncMock, ) as mock_get_agent_token: mock_get_agent_token.return_value = "test-agent-token" @@ -106,17 +133,26 @@ def test_sync_function_decoration_with_running_loop(self): mock_client.get_token = AsyncMock(return_value="test-access-token") # Mock _get_region - with patch("bedrock_agentcore.identity.auth._get_region", return_value="us-west-2"): + with patch( + "bedrock_agentcore.identity.auth._get_region", + return_value="us-west-2", + ): - @requires_access_token(provider_name="test-provider", scopes=["read"], auth_flow="M2M") + @requires_access_token( + provider_name="test-provider", scopes=["read"], auth_flow="M2M" + ) def test_sync_func(param1, access_token=None): return f"param1={param1}, token={access_token}" # Mock asyncio.get_running_loop to succeed (loop is running) with patch("asyncio.get_running_loop"): - with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor_class: + with patch( + "concurrent.futures.ThreadPoolExecutor" + ) as mock_executor_class: mock_executor = Mock() - mock_executor_class.return_value.__enter__.return_value = mock_executor + mock_executor_class.return_value.__enter__.return_value = ( + mock_executor + ) mock_future = Mock() mock_future.result.return_value = "test-access-token" @@ -131,13 +167,16 @@ def test_sync_func(param1, access_token=None): async def test_custom_parameter_name(self): """Test decorator with custom parameter name for token injection.""" # Mock IdentityClient - with patch("bedrock_agentcore.identity.auth.IdentityClient") as mock_identity_client_class: + with patch( + "bedrock_agentcore.identity.auth.IdentityClient" + ) as mock_identity_client_class: mock_client = Mock() mock_identity_client_class.return_value = mock_client # Mock _get_workload_access_token with patch( - "bedrock_agentcore.identity.auth._get_workload_access_token", new_callable=AsyncMock + "bedrock_agentcore.identity.auth._get_workload_access_token", + new_callable=AsyncMock, ) as mock_get_agent_token: mock_get_agent_token.return_value = "test-agent-token" @@ -145,10 +184,16 @@ async def test_custom_parameter_name(self): mock_client.get_token = AsyncMock(return_value="test-access-token") # Mock _get_region - with patch("bedrock_agentcore.identity.auth._get_region", return_value="us-west-2"): + with patch( + "bedrock_agentcore.identity.auth._get_region", + return_value="us-west-2", + ): @requires_access_token( - provider_name="test-provider", into="my_token", scopes=["read"], auth_flow="M2M" + provider_name="test-provider", + into="my_token", + scopes=["read"], + auth_flow="M2M", ) async def test_func(param1, my_token=None): return f"param1={param1}, token={my_token}" @@ -163,6 +208,7 @@ async def test_func(param1, my_token=None): on_auth_url=None, auth_flow="M2M", callback_url=None, + custom_parameters={}, force_authentication=False, token_poller=None, ) @@ -171,13 +217,16 @@ async def test_func(param1, my_token=None): async def test_with_all_optional_parameters(self): """Test decorator with all optional parameters.""" # Mock IdentityClient - with patch("bedrock_agentcore.identity.auth.IdentityClient") as mock_identity_client_class: + with patch( + "bedrock_agentcore.identity.auth.IdentityClient" + ) as mock_identity_client_class: mock_client = Mock() mock_identity_client_class.return_value = mock_client # Mock _get_workload_access_token with patch( - "bedrock_agentcore.identity.auth._get_workload_access_token", new_callable=AsyncMock + "bedrock_agentcore.identity.auth._get_workload_access_token", + new_callable=AsyncMock, ) as mock_get_agent_token: mock_get_agent_token.return_value = "test-agent-token" @@ -185,7 +234,10 @@ async def test_with_all_optional_parameters(self): mock_client.get_token = AsyncMock(return_value="test-access-token") # Mock _get_region - with patch("bedrock_agentcore.identity.auth._get_region", return_value="us-west-2"): + with patch( + "bedrock_agentcore.identity.auth._get_region", + return_value="us-west-2", + ): # Mock callback callback_called = False @@ -203,6 +255,7 @@ def on_auth_url(url): on_auth_url=on_auth_url, auth_flow="USER_FEDERATION", callback_url="https://example.com/callback", + custom_parameters={"audience": "Audience"}, force_authentication=True, token_poller=mock_poller, ) @@ -218,6 +271,7 @@ async def test_func(token=None): scopes=["read", "write"], on_auth_url=on_auth_url, auth_flow="USER_FEDERATION", + custom_parameters={"audience": "Audience"}, callback_url="https://example.com/callback", force_authentication=True, token_poller=mock_poller, @@ -231,13 +285,16 @@ class TestRequiresApiKeyDecorator: async def test_async_function_decoration(self): """Test decorator with async function.""" # Mock IdentityClient - with patch("bedrock_agentcore.identity.auth.IdentityClient") as mock_identity_client_class: + with patch( + "bedrock_agentcore.identity.auth.IdentityClient" + ) as mock_identity_client_class: mock_client = Mock() mock_identity_client_class.return_value = mock_client # Mock _get_workload_access_token with patch( - "bedrock_agentcore.identity.auth._get_workload_access_token", new_callable=AsyncMock + "bedrock_agentcore.identity.auth._get_workload_access_token", + new_callable=AsyncMock, ) as mock_get_agent_token: mock_get_agent_token.return_value = "test-agent-token" @@ -245,7 +302,10 @@ async def test_async_function_decoration(self): mock_client.get_api_key = AsyncMock(return_value="test-api-key") # Mock _get_region - with patch("bedrock_agentcore.identity.auth._get_region", return_value="us-west-2"): + with patch( + "bedrock_agentcore.identity.auth._get_region", + return_value="us-west-2", + ): @requires_api_key(provider_name="test-provider") async def test_async_func(param1, api_key=None): @@ -255,19 +315,23 @@ async def test_async_func(param1, api_key=None): assert result == "param1=value1, key=test-api-key" mock_client.get_api_key.assert_called_once_with( - provider_name="test-provider", agent_identity_token="test-agent-token" + provider_name="test-provider", + agent_identity_token="test-agent-token", ) def test_sync_function_decoration_no_running_loop(self): """Test decorator with sync function when no asyncio loop is running.""" # Mock IdentityClient - with patch("bedrock_agentcore.identity.auth.IdentityClient") as mock_identity_client_class: + with patch( + "bedrock_agentcore.identity.auth.IdentityClient" + ) as mock_identity_client_class: mock_client = Mock() mock_identity_client_class.return_value = mock_client # Mock _get_workload_access_token with patch( - "bedrock_agentcore.identity.auth._get_workload_access_token", new_callable=AsyncMock + "bedrock_agentcore.identity.auth._get_workload_access_token", + new_callable=AsyncMock, ) as mock_get_agent_token: mock_get_agent_token.return_value = "test-agent-token" @@ -275,14 +339,20 @@ def test_sync_function_decoration_no_running_loop(self): mock_client.get_api_key = AsyncMock(return_value="test-api-key") # Mock _get_region - with patch("bedrock_agentcore.identity.auth._get_region", return_value="us-west-2"): + with patch( + "bedrock_agentcore.identity.auth._get_region", + return_value="us-west-2", + ): @requires_api_key(provider_name="test-provider", into="my_key") def test_sync_func(param1, my_key=None): return f"param1={param1}, key={my_key}" # Mock asyncio.get_running_loop to raise RuntimeError (no loop) - with patch("asyncio.get_running_loop", side_effect=RuntimeError("no running loop")): + with patch( + "asyncio.get_running_loop", + side_effect=RuntimeError("no running loop"), + ): with patch("asyncio.run") as mock_asyncio_run: mock_asyncio_run.return_value = "test-api-key" @@ -293,13 +363,16 @@ def test_sync_func(param1, my_key=None): def test_sync_function_decoration_with_running_loop(self): """Test decorator with sync function when asyncio loop is running.""" # Mock IdentityClient - with patch("bedrock_agentcore.identity.auth.IdentityClient") as mock_identity_client_class: + with patch( + "bedrock_agentcore.identity.auth.IdentityClient" + ) as mock_identity_client_class: mock_client = Mock() mock_identity_client_class.return_value = mock_client # Mock _get_workload_access_token with patch( - "bedrock_agentcore.identity.auth._get_workload_access_token", new_callable=AsyncMock + "bedrock_agentcore.identity.auth._get_workload_access_token", + new_callable=AsyncMock, ) as mock_get_agent_token: mock_get_agent_token.return_value = "test-agent-token" @@ -307,7 +380,10 @@ def test_sync_function_decoration_with_running_loop(self): mock_client.get_api_key = AsyncMock(return_value="test-api-key") # Mock _get_region - with patch("bedrock_agentcore.identity.auth._get_region", return_value="us-west-2"): + with patch( + "bedrock_agentcore.identity.auth._get_region", + return_value="us-west-2", + ): @requires_api_key(provider_name="test-provider") def test_sync_func(param1, api_key=None): @@ -315,9 +391,13 @@ def test_sync_func(param1, api_key=None): # Mock asyncio.get_running_loop to succeed (loop is running) with patch("asyncio.get_running_loop"): - with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor_class: + with patch( + "concurrent.futures.ThreadPoolExecutor" + ) as mock_executor_class: mock_executor = Mock() - mock_executor_class.return_value.__enter__.return_value = mock_executor + mock_executor_class.return_value.__enter__.return_value = ( + mock_executor + ) mock_future = Mock() mock_future.result.return_value = "test-api-key" @@ -335,9 +415,14 @@ class TestSetUpLocalAuth: @pytest.mark.asyncio async def test_existing_config(self, tmp_path): """Test when config file exists with both workload_identity_name and user_id.""" - config_content = {"workload_identity_name": "existing-workload-123", "user_id": "existing-user-456"} + config_content = { + "workload_identity_name": "existing-workload-123", + "user_id": "existing-user-456", + } mock_client = Mock() - mock_client.get_workload_access_token = Mock(return_value={"workloadAccessToken": "test-access-token-456"}) + mock_client.get_workload_access_token = Mock( + return_value={"workloadAccessToken": "test-access-token-456"} + ) # Create the config file in the temp directory config_file = tmp_path / ".agentcore.json" @@ -364,8 +449,12 @@ async def test_existing_config(self, tmp_path): async def test_no_config(self, tmp_path): """Test when config file doesn't exist.""" mock_client = Mock() - mock_client.create_workload_identity = Mock(return_value={"name": "test-workload-123"}) - mock_client.get_workload_access_token = Mock(return_value={"workloadAccessToken": "test-access-token-456"}) + mock_client.create_workload_identity = Mock( + return_value={"name": "test-workload-123"} + ) + mock_client.get_workload_access_token = Mock( + return_value={"workloadAccessToken": "test-access-token-456"} + ) # Change to the temp directory for the test import os @@ -382,7 +471,9 @@ async def test_no_config(self, tmp_path): # Should create new workload identity and user_id assert result == "test-access-token-456" mock_client.create_workload_identity.assert_called_once() - mock_client.get_workload_access_token.assert_called_once_with("test-workload-123", user_id="abcd1234") + mock_client.get_workload_access_token.assert_called_once_with( + "test-workload-123", user_id="abcd1234" + ) # Verify that the config file was created config_file = tmp_path / ".agentcore.json" @@ -448,7 +539,10 @@ async def test_no_context_local_dev(self): with patch("os.getenv") as mock_getenv: mock_getenv.return_value = None # Not in Docker - with patch("bedrock_agentcore.identity.auth._set_up_local_auth", new_callable=AsyncMock) as mock_setup: + with patch( + "bedrock_agentcore.identity.auth._set_up_local_auth", + new_callable=AsyncMock, + ) as mock_setup: mock_setup.return_value = "local-dev-token-456" result = await _get_workload_access_token(mock_client) @@ -470,7 +564,9 @@ async def test_no_context_docker_container(self): with patch("os.getenv") as mock_getenv: mock_getenv.return_value = "1" # In Docker container - with pytest.raises(ValueError, match="Workload access token has not been set"): + with pytest.raises( + ValueError, match="Workload access token has not been set" + ): await _get_workload_access_token(mock_client) mock_get_token.assert_called_once() diff --git a/tests/bedrock_agentcore/services/test_identity.py b/tests/bedrock_agentcore/services/test_identity.py index 4c0b723..3015ea6 100644 --- a/tests/bedrock_agentcore/services/test_identity.py +++ b/tests/bedrock_agentcore/services/test_identity.py @@ -43,7 +43,9 @@ def test_create_oauth2_credential_provider(self): # Test data req = {"name": "test-provider", "clientId": "test-client"} expected_response = {"providerId": "test-provider-id"} - mock_client.create_oauth2_credential_provider.return_value = expected_response + mock_client.create_oauth2_credential_provider.return_value = ( + expected_response + ) result = identity_client.create_oauth2_credential_provider(req) @@ -63,12 +65,16 @@ def test_create_api_key_credential_provider(self): # Test data req = {"name": "test-api-provider", "apiKeyName": "test-key"} expected_response = {"providerId": "test-api-provider-id"} - mock_client.create_api_key_credential_provider.return_value = expected_response + mock_client.create_api_key_credential_provider.return_value = ( + expected_response + ) result = identity_client.create_api_key_credential_provider(req) assert result == expected_response - mock_client.create_api_key_credential_provider.assert_called_once_with(**req) + mock_client.create_api_key_credential_provider.assert_called_once_with( + **req + ) @pytest.mark.asyncio async def test_get_token_direct_response(self): @@ -87,10 +93,15 @@ async def test_get_token_direct_response(self): agent_identity_token = "test-agent-token" expected_token = "test-access-token" - mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token} + mock_client.get_resource_oauth2_token.return_value = { + "accessToken": expected_token + } result = await identity_client.get_token( - provider_name=provider_name, scopes=scopes, agent_identity_token=agent_identity_token, auth_flow="M2M" + provider_name=provider_name, + scopes=scopes, + agent_identity_token=agent_identity_token, + auth_flow="M2M", ) assert result == expected_token @@ -98,6 +109,7 @@ async def test_get_token_direct_response(self): resourceCredentialProviderName=provider_name, scopes=scopes, oauth2Flow="M2M", + customParameters={}, workloadIdentityToken=agent_identity_token, ) @@ -128,9 +140,14 @@ async def test_get_token_with_auth_url_polling(self): mock_poller = Mock() mock_poller.poll_for_token = AsyncMock(return_value=expected_token) - with patch("bedrock_agentcore.services.identity._DefaultApiTokenPoller", return_value=mock_poller): + with patch( + "bedrock_agentcore.services.identity._DefaultApiTokenPoller", + return_value=mock_poller, + ): result = await identity_client.get_token( - provider_name=provider_name, agent_identity_token=agent_identity_token, auth_flow="USER_FEDERATION" + provider_name=provider_name, + agent_identity_token=agent_identity_token, + auth_flow="USER_FEDERATION", ) assert result == expected_token @@ -161,13 +178,18 @@ def on_auth_url(url): callback_called = True assert url == auth_url - mock_client.get_resource_oauth2_token.return_value = {"authorizationUrl": auth_url} + mock_client.get_resource_oauth2_token.return_value = { + "authorizationUrl": auth_url + } # Mock the token poller mock_poller = Mock() mock_poller.poll_for_token = AsyncMock(return_value=expected_token) - with patch("bedrock_agentcore.services.identity._DefaultApiTokenPoller", return_value=mock_poller): + with patch( + "bedrock_agentcore.services.identity._DefaultApiTokenPoller", + return_value=mock_poller, + ): result = await identity_client.get_token( provider_name=provider_name, agent_identity_token=agent_identity_token, @@ -204,13 +226,18 @@ async def on_auth_url(url): callback_called = True assert url == auth_url - mock_client.get_resource_oauth2_token.return_value = {"authorizationUrl": auth_url} + mock_client.get_resource_oauth2_token.return_value = { + "authorizationUrl": auth_url + } # Mock the token poller mock_poller = Mock() mock_poller.poll_for_token = AsyncMock(return_value=expected_token) - with patch("bedrock_agentcore.services.identity._DefaultApiTokenPoller", return_value=mock_poller): + with patch( + "bedrock_agentcore.services.identity._DefaultApiTokenPoller", + return_value=mock_poller, + ): result = await identity_client.get_token( provider_name=provider_name, agent_identity_token=agent_identity_token, @@ -239,9 +266,12 @@ async def test_get_token_with_optional_parameters(self): agent_identity_token = "test-agent-token" callback_url = "https://example.com/callback" force_authentication = True + custom_parameters = {"audience": "Audience"} expected_token = "test-access-token" - mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token} + mock_client.get_resource_oauth2_token.return_value = { + "accessToken": expected_token + } result = await identity_client.get_token( provider_name=provider_name, @@ -249,6 +279,7 @@ async def test_get_token_with_optional_parameters(self): agent_identity_token=agent_identity_token, auth_flow="USER_FEDERATION", callback_url=callback_url, + custom_parameters=custom_parameters, force_authentication=force_authentication, ) @@ -259,6 +290,7 @@ async def test_get_token_with_optional_parameters(self): oauth2Flow="USER_FEDERATION", workloadIdentityToken=agent_identity_token, resourceOauth2ReturnUrl=callback_url, + customParameters=custom_parameters, forceAuthentication=force_authentication, ) @@ -279,7 +311,9 @@ async def test_get_token_with_custom_token_poller(self): auth_url = "https://example.com/auth" expected_token = "test-access-token" - mock_client.get_resource_oauth2_token.return_value = {"authorizationUrl": auth_url} + mock_client.get_resource_oauth2_token.return_value = { + "authorizationUrl": auth_url + } # Mock custom token poller custom_poller = Mock() @@ -319,7 +353,8 @@ async def test_get_api_key_success(self): assert result == expected_api_key mock_client.get_resource_api_key.assert_called_once_with( - resourceCredentialProviderName=provider_name, workloadIdentityToken=agent_identity_token + resourceCredentialProviderName=provider_name, + workloadIdentityToken=agent_identity_token, ) def test_get_workload_access_token_with_user_token(self): @@ -330,19 +365,29 @@ def test_get_workload_access_token_with_user_token(self): mock_cp_client = Mock() mock_identity_client = Mock() mock_dp_client = Mock() - mock_boto_client.side_effect = [mock_cp_client, mock_identity_client, mock_dp_client] + mock_boto_client.side_effect = [ + mock_cp_client, + mock_identity_client, + mock_dp_client, + ] identity_client = IdentityClient(region) # Test data workload_name = "test-workload" user_token = "test-user-jwt-token" - user_id = "test-user-id" # This should be ignored when user_token is provided + user_id = ( + "test-user-id" # This should be ignored when user_token is provided + ) expected_response = {"workloadAccessToken": "test-workload-token"} - mock_dp_client.get_workload_access_token_for_jwt.return_value = expected_response + mock_dp_client.get_workload_access_token_for_jwt.return_value = ( + expected_response + ) - result = identity_client.get_workload_access_token(workload_name, user_token=user_token, user_id=user_id) + result = identity_client.get_workload_access_token( + workload_name, user_token=user_token, user_id=user_id + ) assert result == expected_response mock_dp_client.get_workload_access_token_for_jwt.assert_called_once_with( @@ -360,7 +405,11 @@ def test_get_workload_access_token_with_user_id(self): mock_cp_client = Mock() mock_identity_client = Mock() mock_dp_client = Mock() - mock_boto_client.side_effect = [mock_cp_client, mock_identity_client, mock_dp_client] + mock_boto_client.side_effect = [ + mock_cp_client, + mock_identity_client, + mock_dp_client, + ] identity_client = IdentityClient(region) @@ -369,9 +418,13 @@ def test_get_workload_access_token_with_user_id(self): user_id = "test-user-id" expected_response = {"workloadAccessToken": "test-workload-token"} - mock_dp_client.get_workload_access_token_for_user_id.return_value = expected_response + mock_dp_client.get_workload_access_token_for_user_id.return_value = ( + expected_response + ) - result = identity_client.get_workload_access_token(workload_name, user_id=user_id) + result = identity_client.get_workload_access_token( + workload_name, user_id=user_id + ) assert result == expected_response mock_dp_client.get_workload_access_token_for_user_id.assert_called_once_with( @@ -389,7 +442,11 @@ def test_get_workload_access_token_without_user_info(self): mock_cp_client = Mock() mock_identity_client = Mock() mock_dp_client = Mock() - mock_boto_client.side_effect = [mock_cp_client, mock_identity_client, mock_dp_client] + mock_boto_client.side_effect = [ + mock_cp_client, + mock_identity_client, + mock_dp_client, + ] identity_client = IdentityClient(region) @@ -402,7 +459,9 @@ def test_get_workload_access_token_without_user_info(self): result = identity_client.get_workload_access_token(workload_name) assert result == expected_response - mock_dp_client.get_workload_access_token.assert_called_once_with(workloadName=workload_name) + mock_dp_client.get_workload_access_token.assert_called_once_with( + workloadName=workload_name + ) # Should not call user-specific versions mock_dp_client.get_workload_access_token_for_jwt.assert_not_called() mock_dp_client.get_workload_access_token_for_user_id.assert_not_called() @@ -415,24 +474,40 @@ def test_create_workload_identity(self): mock_cp_client = Mock() mock_identity_client = Mock() mock_dp_client = Mock() - mock_boto_client.side_effect = [mock_cp_client, mock_identity_client, mock_dp_client] + mock_boto_client.side_effect = [ + mock_cp_client, + mock_identity_client, + mock_dp_client, + ] identity_client = IdentityClient(region) # Test with provided name custom_name = "my-custom-workload" - expected_response = {"name": custom_name, "workloadIdentityId": "workload-123"} - mock_identity_client.create_workload_identity.return_value = expected_response + expected_response = { + "name": custom_name, + "workloadIdentityId": "workload-123", + } + mock_identity_client.create_workload_identity.return_value = ( + expected_response + ) result = identity_client.create_workload_identity(name=custom_name) assert result == expected_response - mock_identity_client.create_workload_identity.assert_called_with(name=custom_name) + mock_identity_client.create_workload_identity.assert_called_with( + name=custom_name + ) # Test without provided name (auto-generated) mock_identity_client.reset_mock() - expected_response_auto = {"name": "workload-abcd1234", "workloadIdentityId": "workload-456"} - mock_identity_client.create_workload_identity.return_value = expected_response_auto + expected_response_auto = { + "name": "workload-abcd1234", + "workloadIdentityId": "workload-456", + } + mock_identity_client.create_workload_identity.return_value = ( + expected_response_auto + ) with patch("uuid.uuid4") as mock_uuid: mock_uuid.return_value.hex = "abcd1234efgh5678" @@ -440,7 +515,9 @@ def test_create_workload_identity(self): result = identity_client.create_workload_identity() assert result == expected_response_auto - mock_identity_client.create_workload_identity.assert_called_with(name="workload-abcd1234") + mock_identity_client.create_workload_identity.assert_called_with( + name="workload-abcd1234" + ) class TestDefaultApiTokenPoller: @@ -508,4 +585,6 @@ async def test_poll_for_token_timeout(self): await poller.poll_for_token() assert "Polling timed out" in str(exc_info.value) - assert f"{DEFAULT_POLLING_TIMEOUT_SECONDS} seconds" in str(exc_info.value) + assert f"{DEFAULT_POLLING_TIMEOUT_SECONDS} seconds" in str( + exc_info.value + ) diff --git a/tests_integ/identity/test_auth_flows_auth0.py b/tests_integ/identity/test_auth_flows_auth0.py new file mode 100644 index 0000000..3fa8e5d --- /dev/null +++ b/tests_integ/identity/test_auth_flows_auth0.py @@ -0,0 +1,34 @@ +import asyncio + +from bedrock_agentcore.identity.auth import requires_access_token, requires_api_key + + +@requires_access_token( + provider_name="auth0_3lo", # replace with your Auth0 credential provider name + scopes=["list"], + auth_flow="USER_FEDERATION", + on_auth_url=lambda x: print(x), + custom_parameters={ + "audience": "Auth0Gateway" + }, # replace with the audience associated to your API + force_authentication=True, +) +async def need_token_3LO_async(*, access_token: str): + print(access_token) + + +@requires_access_token( + provider_name="auth0_2lo", # replace with your Auth0 credential provider name + scopes=[], + custom_parameters={ + "audience": "Auth0Gateway" + }, # replace with the audience associated to your API + auth_flow="M2M", +) +async def need_token_2LO_async(*, access_token: str): + print(f"received 2LO token for async func: {access_token}") + + +if __name__ == "__main__": + asyncio.run(need_token_2LO_async(access_token="")) + asyncio.run(need_token_3LO_async(access_token=""))