From d390a2fa3882940212a1e0585dd83fd7d22bf346 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Wed, 6 Apr 2022 18:57:50 +0100 Subject: [PATCH] enhancement(data-api): Add boto3 session to connect --- awswrangler/data_api/rds.py | 14 +++++++++++--- awswrangler/data_api/redshift.py | 21 ++++++++++++++++++--- tests/test_data_api.py | 5 +++-- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/awswrangler/data_api/rds.py b/awswrangler/data_api/rds.py index b1986706b..0cbc0b590 100644 --- a/awswrangler/data_api/rds.py +++ b/awswrangler/data_api/rds.py @@ -7,6 +7,7 @@ import boto3 import pandas as pd +from awswrangler import _utils from awswrangler.data_api import connector @@ -27,6 +28,8 @@ class RdsDataApi(connector.DataApiConnector): Factor by which to increase the sleep between connection attempts to paused clusters - defaults to 1.0. retries: int Maximum number of connection attempts to paused clusters - defaults to 10. + boto3_session : boto3.Session(), optional + The boto3 session. If `None`, the default boto3 session is used. """ def __init__( @@ -37,12 +40,13 @@ def __init__( sleep: float = 0.5, backoff: float = 1.0, retries: int = 30, + boto3_session: Optional[boto3.Session] = None, ) -> None: self.resource_arn = resource_arn self.database = database self.secret_arn = secret_arn self.wait_config = connector.WaitConfig(sleep, backoff, retries) - self.client = boto3.client("rds-data") + self.client: boto3.client = _utils.client(service_name="rds-data", session=boto3_session) self.results: Dict[str, Dict[str, Any]] = {} logger: logging.Logger = logging.getLogger(__name__) super().__init__(self.client, logger) @@ -114,7 +118,9 @@ def _get_statement_result(self, request_id: str) -> pd.DataFrame: return dataframe -def connect(resource_arn: str, database: str, secret_arn: str = "", **kwargs: Any) -> RdsDataApi: +def connect( + resource_arn: str, database: str, secret_arn: str = "", boto3_session: Optional[boto3.Session] = None, **kwargs: Any +) -> RdsDataApi: """Create a RDS Data API connection. Parameters @@ -125,6 +131,8 @@ def connect(resource_arn: str, database: str, secret_arn: str = "", **kwargs: An Target database name. secret_arn: str The ARN for the secret to be used for authentication. + boto3_session : boto3.Session(), optional + The boto3 session. If `None`, the default boto3 session is used. **kwargs Any additional kwargs are passed to the underlying RdsDataApi class. @@ -132,7 +140,7 @@ def connect(resource_arn: str, database: str, secret_arn: str = "", **kwargs: An ------- A RdsDataApi connection instance that can be used with `wr.rds.data_api.read_sql_query`. """ - return RdsDataApi(resource_arn, database, secret_arn=secret_arn, **kwargs) + return RdsDataApi(resource_arn, database, secret_arn=secret_arn, boto3_session=boto3_session, **kwargs) def read_sql_query(sql: str, con: RdsDataApi, database: Optional[str] = None) -> pd.DataFrame: diff --git a/awswrangler/data_api/redshift.py b/awswrangler/data_api/redshift.py index a82f5e7cc..8642a47d9 100644 --- a/awswrangler/data_api/redshift.py +++ b/awswrangler/data_api/redshift.py @@ -6,6 +6,7 @@ import boto3 import pandas as pd +from awswrangler import _utils from awswrangler.data_api import connector @@ -28,6 +29,8 @@ class RedshiftDataApi(connector.DataApiConnector): Factor by which to increase the sleep between result fetch attempts - defaults to 1.5. retries: int Maximum number of result fetch attempts - defaults to 15. + boto3_session : boto3.Session(), optional + The boto3 session. If `None`, the default boto3 session is used. """ def __init__( @@ -39,12 +42,13 @@ def __init__( sleep: float = 0.25, backoff: float = 1.5, retries: int = 15, + boto3_session: Optional[boto3.Session] = None, ) -> None: self.cluster_id = cluster_id self.database = database self.secret_arn = secret_arn self.db_user = db_user - self.client = boto3.client("redshift-data") + self.client: boto3.client = _utils.client(service_name="redshift-data", session=boto3_session) self.waiter = RedshiftDataApiWaiter(self.client, sleep, backoff, retries) logger: logging.Logger = logging.getLogger(__name__) super().__init__(self.client, logger) @@ -162,7 +166,14 @@ class RedshiftDataApiTimeoutException(Exception): """Indicates a statement execution did not complete in the expected wait time.""" -def connect(cluster_id: str, database: str, secret_arn: str = "", db_user: str = "", **kwargs: Any) -> RedshiftDataApi: +def connect( + cluster_id: str, + database: str, + secret_arn: str = "", + db_user: str = "", + boto3_session: Optional[boto3.Session] = None, + **kwargs: Any, +) -> RedshiftDataApi: """Create a Redshift Data API connection. Parameters @@ -175,6 +186,8 @@ def connect(cluster_id: str, database: str, secret_arn: str = "", db_user: str = The ARN for the secret to be used for authentication - only required if `db_user` not provided. db_user: str The database user to generate temporary credentials for - only required if `secret_arn` not provided. + boto3_session : boto3.Session(), optional + The boto3 session. If `None`, the default boto3 session is used. **kwargs Any additional kwargs are passed to the underlying RedshiftDataApi class. @@ -182,7 +195,9 @@ def connect(cluster_id: str, database: str, secret_arn: str = "", db_user: str = ------- A RedshiftDataApi connection instance that can be used with `wr.redshift.data_api.read_sql_query`. """ - return RedshiftDataApi(cluster_id, database, secret_arn=secret_arn, db_user=db_user, **kwargs) + return RedshiftDataApi( + cluster_id, database, secret_arn=secret_arn, db_user=db_user, boto3_session=boto3_session, **kwargs + ) def read_sql_query(sql: str, con: RedshiftDataApi, database: Optional[str] = None) -> pd.DataFrame: diff --git a/tests/test_data_api.py b/tests/test_data_api.py index 48465ad1b..5f36db9ce 100644 --- a/tests/test_data_api.py +++ b/tests/test_data_api.py @@ -1,3 +1,4 @@ +import boto3 import pandas as pd import pytest @@ -11,7 +12,7 @@ def redshift_connector(databases_parameters): cluster_id = databases_parameters["redshift"]["identifier"] database = databases_parameters["redshift"]["database"] secret_arn = databases_parameters["redshift"]["secret_arn"] - conn = wr.data_api.redshift.connect(cluster_id, database, secret_arn=secret_arn) + conn = wr.data_api.redshift.connect(cluster_id, database, secret_arn=secret_arn, boto3_session=None) return conn @@ -19,7 +20,7 @@ def create_rds_connector(rds_type, parameters): cluster_id = parameters[rds_type]["arn"] database = parameters[rds_type]["database"] secret_arn = parameters[rds_type]["secret_arn"] - conn = wr.data_api.rds.connect(cluster_id, database, secret_arn=secret_arn) + conn = wr.data_api.rds.connect(cluster_id, database, secret_arn=secret_arn, boto3_session=boto3.DEFAULT_SESSION) return conn