Skip to content

Commit

Permalink
Support assume role for aws base async hook
Browse files Browse the repository at this point in the history
  • Loading branch information
bharanidharan14 committed Dec 5, 2022
1 parent 3743e09 commit 2681cce
Showing 1 changed file with 45 additions and 4 deletions.
49 changes: 45 additions & 4 deletions astronomer/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Dict

from aiobotocore.client import AioBaseClient
from aiobotocore.session import get_session
from aiobotocore.session import AioSession, get_session
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from asgiref.sync import sync_to_async
Expand Down Expand Up @@ -44,12 +46,51 @@ async def get_client_async(self) -> AioBaseClient:
)

async_connection = get_session()
session_token = conn_config.aws_session_token
aws_secret = conn_config.aws_secret_access_key
aws_access = conn_config.aws_access_key_id
if conn_config.role_arn:
credentials = await self.get_role_credentials(
async_session=async_connection, conn_config=conn_config
)
session_token = credentials["SessionToken"]
aws_access = credentials["AccessKeyId"]
aws_secret = credentials["SecretAccessKey"]
return async_connection.create_client(
service_name=self.client_type,
region_name=conn_config.region_name,
aws_secret_access_key=conn_config.aws_secret_access_key,
aws_access_key_id=conn_config.aws_access_key_id,
aws_session_token=conn_config.aws_session_token,
aws_secret_access_key=aws_access,
aws_access_key_id=aws_secret,
aws_session_token=session_token,
verify=self.verify,
config=self.config,
endpoint_url=conn_config.endpoint_url,
)

@staticmethod
async def get_role_credentials(async_session: AioSession, conn_config) -> Dict[str, Any]:
"""Get the role_arn, method credentials from connection details and get the role credentials detail"""
async with async_session.create_client(
"sts",
aws_access_key_id=conn_config.aws_access_key_id,
aws_secret_access_key=conn_config.aws_secret_access_key,
) as client:
if conn_config.assume_role_method == "assume_role_with_saml":
response = await client.assume_role_with_saml(
RoleArn=conn_config.role_arn,
RoleSessionName="RoleSession",
**conn_config.assume_role_kwargs,
)
elif conn_config.assume_role_method == "assume_role_with_web_identity":
response = await client.assume_role_with_web_identity(
RoleArn=conn_config.role_arn,
RoleSessionName="RoleSession",
**conn_config.assume_role_kwargs,
)
else:
response = await client.assume_role(
RoleArn=conn_config.role_arn,
RoleSessionName="RoleSession",
**conn_config.assume_role_kwargs,
)
return response["Credentials"]

0 comments on commit 2681cce

Please sign in to comment.