Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 68 additions & 10 deletions awswrangler/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,75 @@ def __init__(self, session):
self._session = session

@staticmethod
def generate_connection(database, host, port, user, password):
conn = pg8000.connect(
database=database,
host=host,
port=int(port),
user=user,
password=password,
ssl=False,
)
def _validate_connection(database,
host,
port,
user,
password,
tcp_keepalive=True,
application_name="aws-data-wrangler-validation",
validation_timeout=5):
try:
conn = pg8000.connect(database=database,
host=host,
port=int(port),
user=user,
password=password,
ssl=True,
application_name=application_name,
tcp_keepalive=tcp_keepalive,
timeout=validation_timeout)
conn.close()
except pg8000.core.InterfaceError as e:
raise e

@staticmethod
def generate_connection(database,
host,
port,
user,
password,
tcp_keepalive=True,
application_name="aws-data-wrangler",
connection_timeout=1_200_000,
statement_timeout=1_200_000,
validation_timeout=5):
"""
Generates a valid connection object to be passed to the load_table method

:param database: The name of the database instance to connect with.
:param host: The hostname of the Redshift server to connect with.
:param port: The TCP/IP port of the Redshift server instance.
:param user: The username to connect to the Redshift server with.
:param password: The user password to connect to the server with.
:param tcp_keepalive: If True then use TCP keepalive
:param application_name: Application name
:param connection_timeout: Connection Timeout
:param statement_timeout: Redshift statements timeout
:param validation_timeout: Timeout to try to validate the connection
:return: pg8000 connection
"""
Redshift._validate_connection(database=database,
host=host,
port=port,
user=user,
password=password,
tcp_keepalive=tcp_keepalive,
application_name=application_name,
validation_timeout=validation_timeout)
if isinstance(type(port), str) or isinstance(type(port), float):
port = int(port)
conn = pg8000.connect(database=database,
host=host,
port=int(port),
user=user,
password=password,
ssl=True,
application_name=application_name,
tcp_keepalive=tcp_keepalive,
timeout=connection_timeout)
cursor = conn.cursor()
cursor.execute("set statement_timeout = 1200000")
cursor.execute(f"set statement_timeout = {statement_timeout}")
conn.commit()
cursor.close()
return conn
Expand Down
31 changes: 31 additions & 0 deletions testing/test_awswrangler/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import boto3
import pandas
from pyspark.sql import SparkSession
import pg8000

from awswrangler import Session, Redshift
from awswrangler.exceptions import InvalidRedshiftDiststyle, InvalidRedshiftDistkey, InvalidRedshiftSortstyle, InvalidRedshiftSortkey
Expand Down Expand Up @@ -267,3 +268,33 @@ def test_write_load_manifest(session, bucket):
assert manifest.get("entries")[0].get("url") == object_path
assert manifest.get("entries")[0].get("mandatory")
assert manifest.get("entries")[0].get("meta").get("content_length") == 2247


def test_connection_timeout(redshift_parameters):
with pytest.raises(pg8000.core.InterfaceError):
Redshift.generate_connection(
database="test",
host=redshift_parameters.get("RedshiftAddress"),
port=12345,
user="test",
password=redshift_parameters.get("RedshiftPassword"),
)


def test_connection_with_different_port_types(redshift_parameters):
conn = Redshift.generate_connection(
database="test",
host=redshift_parameters.get("RedshiftAddress"),
port=str(redshift_parameters.get("RedshiftPort")),
user="test",
password=redshift_parameters.get("RedshiftPassword"),
)
conn.close()
conn = Redshift.generate_connection(
database="test",
host=redshift_parameters.get("RedshiftAddress"),
port=float(redshift_parameters.get("RedshiftPort")),
user="test",
password=redshift_parameters.get("RedshiftPassword"),
)
conn.close()