diff --git a/awswrangler/redshift.py b/awswrangler/redshift.py index 54f9d6d5f..1e40e1acc 100644 --- a/awswrangler/redshift.py +++ b/awswrangler/redshift.py @@ -33,17 +33,75 @@ def __init__(self, session): self._session = session @staticmethod - def generate_connection(database, host, port, user, password): - conn = pg8000.connect( - database=database, - host=host, - port=int(port), - user=user, - password=password, - ssl=False, - ) + def _validate_connection(database, + host, + port, + user, + password, + tcp_keepalive=True, + application_name="aws-data-wrangler-validation", + validation_timeout=5): + try: + conn = pg8000.connect(database=database, + host=host, + port=int(port), + user=user, + password=password, + ssl=True, + application_name=application_name, + tcp_keepalive=tcp_keepalive, + timeout=validation_timeout) + conn.close() + except pg8000.core.InterfaceError as e: + raise e + + @staticmethod + def generate_connection(database, + host, + port, + user, + password, + tcp_keepalive=True, + application_name="aws-data-wrangler", + connection_timeout=1_200_000, + statement_timeout=1_200_000, + validation_timeout=5): + """ + Generates a valid connection object to be passed to the load_table method + + :param database: The name of the database instance to connect with. + :param host: The hostname of the Redshift server to connect with. + :param port: The TCP/IP port of the Redshift server instance. + :param user: The username to connect to the Redshift server with. + :param password: The user password to connect to the server with. + :param tcp_keepalive: If True then use TCP keepalive + :param application_name: Application name + :param connection_timeout: Connection Timeout + :param statement_timeout: Redshift statements timeout + :param validation_timeout: Timeout to try to validate the connection + :return: pg8000 connection + """ + Redshift._validate_connection(database=database, + host=host, + port=port, + user=user, + password=password, + tcp_keepalive=tcp_keepalive, + application_name=application_name, + validation_timeout=validation_timeout) + if isinstance(type(port), str) or isinstance(type(port), float): + port = int(port) + conn = pg8000.connect(database=database, + host=host, + port=int(port), + user=user, + password=password, + ssl=True, + application_name=application_name, + tcp_keepalive=tcp_keepalive, + timeout=connection_timeout) cursor = conn.cursor() - cursor.execute("set statement_timeout = 1200000") + cursor.execute(f"set statement_timeout = {statement_timeout}") conn.commit() cursor.close() return conn diff --git a/testing/test_awswrangler/test_redshift.py b/testing/test_awswrangler/test_redshift.py index f54a00420..679ebd049 100644 --- a/testing/test_awswrangler/test_redshift.py +++ b/testing/test_awswrangler/test_redshift.py @@ -5,6 +5,7 @@ import boto3 import pandas from pyspark.sql import SparkSession +import pg8000 from awswrangler import Session, Redshift from awswrangler.exceptions import InvalidRedshiftDiststyle, InvalidRedshiftDistkey, InvalidRedshiftSortstyle, InvalidRedshiftSortkey @@ -267,3 +268,33 @@ def test_write_load_manifest(session, bucket): assert manifest.get("entries")[0].get("url") == object_path assert manifest.get("entries")[0].get("mandatory") assert manifest.get("entries")[0].get("meta").get("content_length") == 2247 + + +def test_connection_timeout(redshift_parameters): + with pytest.raises(pg8000.core.InterfaceError): + Redshift.generate_connection( + database="test", + host=redshift_parameters.get("RedshiftAddress"), + port=12345, + user="test", + password=redshift_parameters.get("RedshiftPassword"), + ) + + +def test_connection_with_different_port_types(redshift_parameters): + conn = Redshift.generate_connection( + database="test", + host=redshift_parameters.get("RedshiftAddress"), + port=str(redshift_parameters.get("RedshiftPort")), + user="test", + password=redshift_parameters.get("RedshiftPassword"), + ) + conn.close() + conn = Redshift.generate_connection( + database="test", + host=redshift_parameters.get("RedshiftAddress"), + port=float(redshift_parameters.get("RedshiftPort")), + user="test", + password=redshift_parameters.get("RedshiftPassword"), + ) + conn.close()