Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/bedrock_agentcore/identity/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def requires_access_token(
on_auth_url: Optional[Callable[[str], Any]] = None,
auth_flow: Literal["M2M", "USER_FEDERATION"],
callback_url: Optional[str] = None,
custom_parameters: Optional[dict] = {},
force_authentication: bool = False,
token_poller: Optional[TokenPoller] = None,
) -> Callable:
Expand All @@ -38,6 +39,7 @@ def requires_access_token(
on_auth_url: Callback for handling authorization URLs
auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION")
callback_url: OAuth2 callback URL
custom_parameters: optional parameters to be sent to the authorizer endpoint of the provider
force_authentication: Force re-authentication
token_poller: Custom token poller implementation

Expand All @@ -57,6 +59,7 @@ async def _get_token() -> str:
on_auth_url=on_auth_url,
auth_flow=auth_flow,
callback_url=callback_url,
custom_parameters=custom_parameters,
force_authentication=force_authentication,
token_poller=token_poller,
)
Expand Down Expand Up @@ -173,7 +176,9 @@ async def _set_up_local_auth(client: IdentityClient) -> str:

workload_identity_name = config.get("workload_identity_name")
if workload_identity_name:
print(f"Found existing workload identity from {config_path.absolute()}: {workload_identity_name}")
print(
f"Found existing workload identity from {config_path.absolute()}: {workload_identity_name}"
)
else:
workload_identity_name = client.create_workload_identity()["name"]
print("Created a workload identity")
Expand All @@ -192,7 +197,9 @@ async def _set_up_local_auth(client: IdentityClient) -> str:
except Exception:
print("Warning: could not write the created workload identity to file")

return client.get_workload_access_token(workload_identity_name, user_id=user_id)["workloadAccessToken"]
return client.get_workload_access_token(workload_identity_name, user_id=user_id)[
"workloadAccessToken"
]


def _get_region() -> str:
Expand Down
59 changes: 46 additions & 13 deletions src/bedrock_agentcore/services/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

import boto3

from bedrock_agentcore._utils.endpoints import get_control_plane_endpoint, get_data_plane_endpoint
from bedrock_agentcore._utils.endpoints import (
get_control_plane_endpoint,
get_data_plane_endpoint,
)


class TokenPoller(ABC):
Expand Down Expand Up @@ -44,7 +47,9 @@ async def poll_for_token(self) -> str:
while time.time() - start_time < DEFAULT_POLLING_TIMEOUT_SECONDS:
await asyncio.sleep(DEFAULT_POLLING_INTERVAL_SECONDS)

self.logger.info("Polling for token for authorization url: %s", self.auth_url)
self.logger.info(
"Polling for token for authorization url: %s", self.auth_url
)
resp = self.polling_func()
if resp is not None:
self.logger.info("Token is ready")
Expand All @@ -63,13 +68,19 @@ def __init__(self, region: str):
"""Initialize the identity client with the specified region."""
self.region = region
self.cp_client = boto3.client(
"bedrock-agentcore-control", region_name=region, endpoint_url=get_control_plane_endpoint(region)
"bedrock-agentcore-control",
region_name=region,
endpoint_url=get_control_plane_endpoint(region),
)
self.identity_client = boto3.client(
"bedrock-agentcore-control", region_name=region, endpoint_url=get_data_plane_endpoint(region)
"bedrock-agentcore-control",
region_name=region,
endpoint_url=get_data_plane_endpoint(region),
)
self.dp_client = boto3.client(
"bedrock-agentcore", region_name=region, endpoint_url=get_data_plane_endpoint(region)
"bedrock-agentcore",
region_name=region,
endpoint_url=get_data_plane_endpoint(region),
)
self.logger = logging.getLogger("bedrock_agentcore.identity_client")

Expand All @@ -84,17 +95,26 @@ def create_api_key_credential_provider(self, req):
return self.cp_client.create_api_key_credential_provider(**req)

def get_workload_access_token(
self, workload_name: str, user_token: Optional[str] = None, user_id: Optional[str] = None
self,
workload_name: str,
user_token: Optional[str] = None,
user_id: Optional[str] = None,
) -> Dict:
"""Get a workload access token using workload name and optionally user token."""
if user_token:
if user_id is not None:
self.logger.warning("Both user token and user id are supplied, using user token")
self.logger.warning(
"Both user token and user id are supplied, using user token"
)
self.logger.info("Getting workload access token for JWT...")
resp = self.dp_client.get_workload_access_token_for_jwt(workloadName=workload_name, userToken=user_token)
resp = self.dp_client.get_workload_access_token_for_jwt(
workloadName=workload_name, userToken=user_token
)
elif user_id:
self.logger.info("Getting workload access token for user id...")
resp = self.dp_client.get_workload_access_token_for_user_id(workloadName=workload_name, userId=user_id)
resp = self.dp_client.get_workload_access_token_for_user_id(
workloadName=workload_name, userId=user_id
)
else:
self.logger.info("Getting workload access token...")
resp = self.dp_client.get_workload_access_token(workloadName=workload_name)
Expand All @@ -118,6 +138,7 @@ async def get_token(
on_auth_url: Optional[Callable[[str], Any]] = None,
auth_flow: Literal["M2M", "USER_FEDERATION"],
callback_url: Optional[str] = None,
custom_parameters: Optional[dict] = {},
force_authentication: bool = False,
token_poller: Optional[TokenPoller] = None,
) -> str:
Expand All @@ -130,6 +151,7 @@ async def get_token(
on_auth_url: Callback for handling authorization URLs
auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION")
callback_url: OAuth2 callback URL (must be pre-registered)
custom_parameters: optional parameters to be sent to the authorizer endpoint of the provider
force_authentication: Force re-authentication even if token exists in the token vault
token_poller: Custom token poller implementation

Expand All @@ -148,6 +170,7 @@ async def get_token(
"scopes": scopes,
"oauth2Flow": auth_flow,
"workloadIdentityToken": agent_identity_token,
"customParameters": custom_parameters,
}

# Add optional parameters
Expand Down Expand Up @@ -178,15 +201,25 @@ async def get_token(

# Poll for the token
active_poller = token_poller or _DefaultApiTokenPoller(
auth_url, lambda: self.dp_client.get_resource_oauth2_token(**req).get("accessToken", None)
auth_url,
lambda: self.dp_client.get_resource_oauth2_token(**req).get(
"accessToken", None
),
)
return await active_poller.poll_for_token()

raise RuntimeError("Identity service did not return a token or an authorization URL.")
raise RuntimeError(
"Identity service did not return a token or an authorization URL."
)

async def get_api_key(self, *, provider_name: str, agent_identity_token: str) -> str:
async def get_api_key(
self, *, provider_name: str, agent_identity_token: str
) -> str:
"""Programmatically retrieves an API key from the Identity service."""
self.logger.info("Getting API key...")
req = {"resourceCredentialProviderName": provider_name, "workloadIdentityToken": agent_identity_token}
req = {
"resourceCredentialProviderName": provider_name,
"workloadIdentityToken": agent_identity_token,
}

return self.dp_client.get_resource_api_key(**req)["apiKey"]
Loading
Loading