Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 80 additions & 6 deletions awswrangler/redshift.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Amazon Redshift Module."""
# pylint: disable=too-many-lines
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why this is applied at the file level instead of the offending method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is actually regarding too many lines in module:
awswrangler/redshift.py:1:0: C0302: Too many lines in module (1507/1500) (too-many-lines)


import logging
import uuid
Expand Down Expand Up @@ -30,13 +31,34 @@ def _validate_connection(con: redshift_connector.Connection) -> None:
)


def _drop_table(cursor: redshift_connector.Cursor, schema: Optional[str], table: str) -> None:
def _begin_transaction(cursor: redshift_connector.Cursor) -> None:
    """Open an explicit transaction on ``cursor``.

    The statement is logged at debug level before it is executed.

    Parameters
    ----------
    cursor : redshift_connector.Cursor
        Cursor on which the ``BEGIN TRANSACTION`` statement is executed.
    """
    statement = "BEGIN TRANSACTION"
    _logger.debug("Begin transaction query:\n%s", statement)
    cursor.execute(statement)


def _drop_table(cursor: redshift_connector.Cursor, schema: Optional[str], table: str, cascade: bool = False) -> None:
    """Drop ``table`` if it exists, optionally cascading.

    Parameters
    ----------
    cursor : redshift_connector.Cursor
        Cursor on which the statement is executed.
    schema : str, optional
        Schema used to qualify the table name. When empty or None the
        table name is left unqualified.
    table : str
        Name of the table to drop.
    cascade : bool
        When True, ``CASCADE`` is appended to the statement.
    """
    # Both identifiers are wrapped in double quotes; the schema part is only
    # emitted when a schema was actually supplied.
    qualifier = f'"{schema}".' if schema else ""
    suffix = " CASCADE" if cascade else ""
    statement = f'DROP TABLE IF EXISTS {qualifier}"{table}"{suffix}'
    _logger.debug("Drop table query:\n%s", statement)
    cursor.execute(statement)


def _truncate_table(cursor: redshift_connector.Cursor, schema: Optional[str], table: str) -> None:
    """Issue ``TRUNCATE TABLE`` against ``table``.

    Parameters
    ----------
    cursor : redshift_connector.Cursor
        Cursor on which the statement is executed.
    schema : str, optional
        Schema used to qualify the table name. When empty or None the
        table name is left unqualified.
    table : str
        Name of the table to truncate.
    """
    qualifier = f'"{schema}".' if schema else ""
    statement = f'TRUNCATE TABLE {qualifier}"{table}"'
    _logger.debug("Truncate table query:\n%s", statement)
    cursor.execute(statement)


def _delete_all(cursor: redshift_connector.Cursor, schema: Optional[str], table: str) -> None:
    """Delete every row of ``table`` with an unfiltered ``DELETE FROM``.

    Parameters
    ----------
    cursor : redshift_connector.Cursor
        Cursor on which the statement is executed.
    schema : str, optional
        Schema used to qualify the table name. When empty or None the
        table name is left unqualified.
    table : str
        Name of the table whose rows are deleted.
    """
    qualifier = f'"{schema}".' if schema else ""
    statement = f'DELETE FROM {qualifier}"{table}"'
    _logger.debug("Delete query:\n%s", statement)
    cursor.execute(statement)


def _get_primary_keys(cursor: redshift_connector.Cursor, schema: str, table: str) -> List[str]:
cursor.execute(f"SELECT indexdef FROM pg_indexes WHERE schemaname = '{schema}' AND tablename = '{table}'")
result: str = cursor.fetchall()[0][0]
Expand Down Expand Up @@ -214,13 +236,15 @@ def _redshift_types_from_path(
return redshift_types


def _create_table(
def _create_table( # pylint: disable=too-many-locals,too-many-arguments
df: Optional[pd.DataFrame],
path: Optional[Union[str, List[str]]],
con: redshift_connector.Connection,
cursor: redshift_connector.Cursor,
table: str,
schema: str,
mode: str,
overwrite_method: str,
index: bool,
dtype: Optional[Dict[str, str]],
diststyle: str,
Expand All @@ -238,7 +262,25 @@ def _create_table(
s3_additional_kwargs: Optional[Dict[str, str]] = None,
) -> Tuple[str, Optional[str]]:
if mode == "overwrite":
_drop_table(cursor=cursor, schema=schema, table=table)
if overwrite_method == "truncate":
try:
# Truncate commits current transaction, if successful.
# Fast, but not atomic.
_truncate_table(cursor=cursor, schema=schema, table=table)
except redshift_connector.error.ProgrammingError as e:
# Caught "relation does not exist".
if e.args[0]["C"] != "42P01": # pylint: disable=invalid-sequence-index
raise e
_logger.debug(str(e))
con.rollback()
_begin_transaction(cursor=cursor)
elif overwrite_method == "delete":
if _does_table_exist(cursor=cursor, schema=schema, table=table):
# Atomic, but slow.
_delete_all(cursor=cursor, schema=schema, table=table)
else:
# Fast, atomic, but either fails if there are any dependent views or, in cascade mode, deletes them.
_drop_table(cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "cascade"))
elif _does_table_exist(cursor=cursor, schema=schema, table=table) is True:
if mode == "upsert":
guid: str = uuid.uuid4().hex
Expand Down Expand Up @@ -649,6 +691,7 @@ def to_sql(
table: str,
schema: str,
mode: str = "append",
overwrite_method: str = "drop",
index: bool = False,
dtype: Optional[Dict[str, str]] = None,
diststyle: str = "AUTO",
Expand Down Expand Up @@ -682,6 +725,14 @@ def to_sql(
Schema name
mode : str
Append, overwrite or upsert.
overwrite_method : str
Drop, cascade, truncate, or delete. Only applicable in overwrite mode.

"drop" - ``DROP ... RESTRICT`` - drops the table. Fails if there are any views that depend on it.
"cascade" - ``DROP ... CASCADE`` - drops the table, and all views that depend on it.
"truncate" - ``TRUNCATE ...`` - truncates the table, but immediately commits the current
transaction & starts a new one, hence the overwrite happens in two transactions and is not atomic.
"delete" - ``DELETE FROM ...`` - deletes all rows from the table. Slow relative to the other methods.
index : bool
True to store the DataFrame index as a column in the table,
otherwise False to ignore it.
Expand Down Expand Up @@ -744,10 +795,12 @@ def to_sql(
created_table, created_schema = _create_table(
df=df,
path=None,
con=con,
cursor=cursor,
table=table,
schema=schema,
mode=mode,
overwrite_method=overwrite_method,
index=index,
dtype=dtype,
diststyle=diststyle,
Expand Down Expand Up @@ -1073,6 +1126,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
aws_session_token: Optional[str] = None,
parquet_infer_sampling: float = 1.0,
mode: str = "append",
overwrite_method: str = "drop",
diststyle: str = "AUTO",
distkey: Optional[str] = None,
sortstyle: str = "COMPOUND",
Expand Down Expand Up @@ -1130,6 +1184,14 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
The lower, the faster.
mode : str
Append, overwrite or upsert.
overwrite_method : str
Drop, cascade, truncate, or delete. Only applicable in overwrite mode.

"drop" - ``DROP ... RESTRICT`` - drops the table. Fails if there are any views that depend on it.
"cascade" - ``DROP ... CASCADE`` - drops the table, and all views that depend on it.
"truncate" - ``TRUNCATE ...`` - truncates the table, but immediately commits the current
transaction & starts a new one, hence the overwrite happens in two transactions and is not atomic.
"delete" - ``DELETE FROM ...`` - deletes all rows from the table. Slow relative to the other methods.
diststyle : str
Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"].
https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html
Expand Down Expand Up @@ -1202,10 +1264,12 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
parquet_infer_sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
con=con,
cursor=cursor,
table=table,
schema=schema,
mode=mode,
overwrite_method=overwrite_method,
diststyle=diststyle,
sortstyle=sortstyle,
distkey=distkey,
Expand Down Expand Up @@ -1260,6 +1324,7 @@ def copy( # pylint: disable=too-many-arguments
index: bool = False,
dtype: Optional[Dict[str, str]] = None,
mode: str = "append",
overwrite_method: str = "drop",
diststyle: str = "AUTO",
distkey: Optional[str] = None,
sortstyle: str = "COMPOUND",
Expand Down Expand Up @@ -1327,9 +1392,17 @@ def copy( # pylint: disable=too-many-arguments
Useful when you have columns with undetermined or mixed data types.
Only takes effect if dataset=True.
(e.g. {'col name': 'bigint', 'col2 name': 'int'})
mode : str
mode : str
Append, overwrite or upsert.
diststyle : str
overwrite_method : str
Drop, cascade, truncate, or delete. Only applicable in overwrite mode.

"drop" - ``DROP ... RESTRICT`` - drops the table. Fails if there are any views that depend on it.
"cascade" - ``DROP ... CASCADE`` - drops the table, and all views that depend on it.
"truncate" - ``TRUNCATE ...`` - truncates the table, but immediately commits the current
transaction & starts a new one, hence the overwrite happens in two transactions and is not atomic.
"delete" - ``DELETE FROM ...`` - deletes all rows from the table. Slow relative to the other methods.
diststyle : str
Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"].
https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html
distkey : str, optional
Expand Down Expand Up @@ -1416,6 +1489,7 @@ def copy( # pylint: disable=too-many-arguments
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
mode=mode,
overwrite_method=overwrite_method,
diststyle=diststyle,
distkey=distkey,
sortstyle=sortstyle,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ def test_read_sql_query_simple(databases_parameters):
assert df.shape == (1, 1)


def test_to_sql_simple(redshift_table, redshift_con):
@pytest.mark.parametrize("overwrite_method", [None, "drop", "cascade", "truncate", "delete"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

def test_to_sql_simple(redshift_table, redshift_con, overwrite_method):
    """Write a small frame in overwrite mode using the given ``overwrite_method``."""
    df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]})
    # Keyword arguments keep ``overwrite_method`` and ``index`` from being
    # transposed: both sit next to each other in the positional order, and a
    # bare ``True`` in the sixth slot would be taken as the overwrite method.
    wr.redshift.to_sql(
        df,
        redshift_con,
        table=redshift_table,
        schema="public",
        mode="overwrite",
        overwrite_method=overwrite_method,
        index=True,
    )


def test_sql_types(redshift_table, redshift_con):
Expand Down