Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions awswrangler/data_api/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import boto3
import pandas as pd

from awswrangler import _utils
from awswrangler.data_api import connector


Expand All @@ -27,6 +28,8 @@ class RdsDataApi(connector.DataApiConnector):
Factor by which to increase the sleep between connection attempts to paused clusters - defaults to 1.0.
retries: int
Maximum number of connection attempts to paused clusters - defaults to 10.
boto3_session : boto3.Session(), optional
The boto3 session. If `None`, the default boto3 session is used.
"""

def __init__(
Expand All @@ -37,12 +40,13 @@ def __init__(
sleep: float = 0.5,
backoff: float = 1.0,
retries: int = 30,
boto3_session: Optional[boto3.Session] = None,
) -> None:
self.resource_arn = resource_arn
self.database = database
self.secret_arn = secret_arn
self.wait_config = connector.WaitConfig(sleep, backoff, retries)
self.client = boto3.client("rds-data")
self.client: boto3.client = _utils.client(service_name="rds-data", session=boto3_session)
self.results: Dict[str, Dict[str, Any]] = {}
logger: logging.Logger = logging.getLogger(__name__)
super().__init__(self.client, logger)
Expand Down Expand Up @@ -114,7 +118,9 @@ def _get_statement_result(self, request_id: str) -> pd.DataFrame:
return dataframe


def connect(resource_arn: str, database: str, secret_arn: str = "", **kwargs: Any) -> RdsDataApi:
def connect(
resource_arn: str, database: str, secret_arn: str = "", boto3_session: Optional[boto3.Session] = None, **kwargs: Any
) -> RdsDataApi:
"""Create a RDS Data API connection.

Parameters
Expand All @@ -125,14 +131,16 @@ def connect(resource_arn: str, database: str, secret_arn: str = "", **kwargs: An
Target database name.
secret_arn: str
The ARN for the secret to be used for authentication.
boto3_session : boto3.Session(), optional
The boto3 session. If `None`, the default boto3 session is used.
**kwargs
Any additional kwargs are passed to the underlying RdsDataApi class.

Returns
-------
A RdsDataApi connection instance that can be used with `wr.rds.data_api.read_sql_query`.
"""
return RdsDataApi(resource_arn, database, secret_arn=secret_arn, **kwargs)
return RdsDataApi(resource_arn, database, secret_arn=secret_arn, boto3_session=boto3_session, **kwargs)


def read_sql_query(sql: str, con: RdsDataApi, database: Optional[str] = None) -> pd.DataFrame:
Expand Down
21 changes: 18 additions & 3 deletions awswrangler/data_api/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import boto3
import pandas as pd

from awswrangler import _utils
from awswrangler.data_api import connector


Expand All @@ -28,6 +29,8 @@ class RedshiftDataApi(connector.DataApiConnector):
Factor by which to increase the sleep between result fetch attempts - defaults to 1.5.
retries: int
Maximum number of result fetch attempts - defaults to 15.
boto3_session : boto3.Session(), optional
The boto3 session. If `None`, the default boto3 session is used.
"""

def __init__(
Expand All @@ -39,12 +42,13 @@ def __init__(
sleep: float = 0.25,
backoff: float = 1.5,
retries: int = 15,
boto3_session: Optional[boto3.Session] = None,
) -> None:
self.cluster_id = cluster_id
self.database = database
self.secret_arn = secret_arn
self.db_user = db_user
self.client = boto3.client("redshift-data")
self.client: boto3.client = _utils.client(service_name="redshift-data", session=boto3_session)
self.waiter = RedshiftDataApiWaiter(self.client, sleep, backoff, retries)
logger: logging.Logger = logging.getLogger(__name__)
super().__init__(self.client, logger)
Expand Down Expand Up @@ -162,7 +166,14 @@ class RedshiftDataApiTimeoutException(Exception):
"""Indicates a statement execution did not complete in the expected wait time."""


def connect(cluster_id: str, database: str, secret_arn: str = "", db_user: str = "", **kwargs: Any) -> RedshiftDataApi:
def connect(
cluster_id: str,
database: str,
secret_arn: str = "",
db_user: str = "",
boto3_session: Optional[boto3.Session] = None,
**kwargs: Any,
) -> RedshiftDataApi:
"""Create a Redshift Data API connection.

Parameters
Expand All @@ -175,14 +186,18 @@ def connect(cluster_id: str, database: str, secret_arn: str = "", db_user: str =
The ARN for the secret to be used for authentication - only required if `db_user` not provided.
db_user: str
The database user to generate temporary credentials for - only required if `secret_arn` not provided.
boto3_session : boto3.Session(), optional
The boto3 session. If `None`, the default boto3 session is used.
**kwargs
Any additional kwargs are passed to the underlying RedshiftDataApi class.

Returns
-------
A RedshiftDataApi connection instance that can be used with `wr.redshift.data_api.read_sql_query`.
"""
return RedshiftDataApi(cluster_id, database, secret_arn=secret_arn, db_user=db_user, **kwargs)
return RedshiftDataApi(
cluster_id, database, secret_arn=secret_arn, db_user=db_user, boto3_session=boto3_session, **kwargs
)


def read_sql_query(sql: str, con: RedshiftDataApi, database: Optional[str] = None) -> pd.DataFrame:
Expand Down
5 changes: 3 additions & 2 deletions tests/test_data_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import boto3
import pandas as pd
import pytest

Expand All @@ -11,15 +12,15 @@ def redshift_connector(databases_parameters):
cluster_id = databases_parameters["redshift"]["identifier"]
database = databases_parameters["redshift"]["database"]
secret_arn = databases_parameters["redshift"]["secret_arn"]
conn = wr.data_api.redshift.connect(cluster_id, database, secret_arn=secret_arn)
conn = wr.data_api.redshift.connect(cluster_id, database, secret_arn=secret_arn, boto3_session=None)
return conn


def create_rds_connector(rds_type, parameters):
cluster_id = parameters[rds_type]["arn"]
database = parameters[rds_type]["database"]
secret_arn = parameters[rds_type]["secret_arn"]
conn = wr.data_api.rds.connect(cluster_id, database, secret_arn=secret_arn)
conn = wr.data_api.rds.connect(cluster_id, database, secret_arn=secret_arn, boto3_session=boto3.DEFAULT_SESSION)
return conn


Expand Down