diff --git a/awswrangler/redshift.py b/awswrangler/redshift.py index f3e8127f9..0d3a08f34 100644 --- a/awswrangler/redshift.py +++ b/awswrangler/redshift.py @@ -1,4 +1,5 @@ """Amazon Redshift Module.""" +# pylint: disable=too-many-lines import logging import uuid @@ -30,13 +31,34 @@ def _validate_connection(con: redshift_connector.Connection) -> None: ) -def _drop_table(cursor: redshift_connector.Cursor, schema: Optional[str], table: str) -> None: +def _begin_transaction(cursor: redshift_connector.Cursor) -> None: + sql = "BEGIN TRANSACTION" + _logger.debug("Begin transaction query:\n%s", sql) + cursor.execute(sql) + + +def _drop_table(cursor: redshift_connector.Cursor, schema: Optional[str], table: str, cascade: bool = False) -> None: schema_str = f'"{schema}".' if schema else "" - sql = f'DROP TABLE IF EXISTS {schema_str}"{table}"' + cascade_str = " CASCADE" if cascade else "" + sql = f'DROP TABLE IF EXISTS {schema_str}"{table}"' f"{cascade_str}" _logger.debug("Drop table query:\n%s", sql) cursor.execute(sql) +def _truncate_table(cursor: redshift_connector.Cursor, schema: Optional[str], table: str) -> None: + schema_str = f'"{schema}".' if schema else "" + sql = f'TRUNCATE TABLE {schema_str}"{table}"' + _logger.debug("Truncate table query:\n%s", sql) + cursor.execute(sql) + + +def _delete_all(cursor: redshift_connector.Cursor, schema: Optional[str], table: str) -> None: + schema_str = f'"{schema}".' if schema else "" + sql = f'DELETE FROM {schema_str}"{table}"' + _logger.debug("Delete query:\n%s", sql) + cursor.execute(sql) + + def _get_primary_keys(cursor: redshift_connector.Cursor, schema: str, table: str) -> List[str]: cursor.execute(f"SELECT indexdef FROM pg_indexes WHERE schemaname = '{schema}' AND tablename = '{table}'") result: str = cursor.fetchall()[0][0] @@ -214,13 +236,15 @@ def _redshift_types_from_path( return redshift_types -def _create_table( +def _create_table( # pylint: disable=too-many-locals,too-many-arguments df: Optional[pd.DataFrame], path: Optional[Union[str, List[str]]], + con: redshift_connector.Connection, cursor: redshift_connector.Cursor, table: str, schema: str, mode: str, + overwrite_method: str, index: bool, dtype: Optional[Dict[str, str]], diststyle: str, @@ -238,7 +262,25 @@ def _create_table( s3_additional_kwargs: Optional[Dict[str, str]] = None, ) -> Tuple[str, Optional[str]]: if mode == "overwrite": - _drop_table(cursor=cursor, schema=schema, table=table) + if overwrite_method == "truncate": + try: + # Truncate commits current transaction, if successful. + # Fast, but not atomic. + _truncate_table(cursor=cursor, schema=schema, table=table) + except redshift_connector.error.ProgrammingError as e: + # Caught "relation does not exist". + if e.args[0]["C"] != "42P01": # pylint: disable=invalid-sequence-index + raise e + _logger.debug(str(e)) + con.rollback() + _begin_transaction(cursor=cursor) + elif overwrite_method == "delete": + if _does_table_exist(cursor=cursor, schema=schema, table=table): + # Atomic, but slow. + _delete_all(cursor=cursor, schema=schema, table=table) + else: + # Fast, atomic, but either fails if there are any dependent views or, in cascade mode, deletes them. + _drop_table(cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "cascade")) elif _does_table_exist(cursor=cursor, schema=schema, table=table) is True: if mode == "upsert": guid: str = uuid.uuid4().hex @@ -649,6 +691,7 @@ def to_sql( table: str, schema: str, mode: str = "append", + overwrite_method: str = "drop", index: bool = False, dtype: Optional[Dict[str, str]] = None, diststyle: str = "AUTO", @@ -682,6 +725,14 @@ def to_sql( Schema name mode : str Append, overwrite or upsert. + overwrite_method : str + Drop, cascade, truncate, or delete. Only applicable in overwrite mode. + + "drop" - ``DROP ... RESTRICT`` - drops the table. Fails if there are any views that depend on it. + "cascade" - ``DROP ... CASCADE`` - drops the table, and all views that depend on it. + "truncate" - ``TRUNCATE ...`` - truncates the table, but immediatly commits current + transaction & starts a new one, hence the overwrite happens in two transactions and is not atomic. + "delete" - ``DELETE FROM ...`` - deletes all rows from the table. Slow relative to the other methods. index : bool True to store the DataFrame index as a column in the table, otherwise False to ignore it. @@ -744,10 +795,12 @@ def to_sql( created_table, created_schema = _create_table( df=df, path=None, + con=con, cursor=cursor, table=table, schema=schema, mode=mode, + overwrite_method=overwrite_method, index=index, dtype=dtype, diststyle=diststyle, @@ -1073,6 +1126,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments aws_session_token: Optional[str] = None, parquet_infer_sampling: float = 1.0, mode: str = "append", + overwrite_method: str = "drop", diststyle: str = "AUTO", distkey: Optional[str] = None, sortstyle: str = "COMPOUND", @@ -1130,6 +1184,14 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments The lower, the faster. mode : str Append, overwrite or upsert. + overwrite_method : str + Drop, cascade, truncate, or delete. Only applicable in overwrite mode. + + "drop" - ``DROP ... RESTRICT`` - drops the table. Fails if there are any views that depend on it. + "cascade" - ``DROP ... CASCADE`` - drops the table, and all views that depend on it. + "truncate" - ``TRUNCATE ...`` - truncates the table, but immediatly commits current + transaction & starts a new one, hence the overwrite happens in two transactions and is not atomic. + "delete" - ``DELETE FROM ...`` - deletes all rows from the table. Slow relative to the other methods. diststyle : str Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"]. https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html @@ -1202,10 +1264,12 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments parquet_infer_sampling=parquet_infer_sampling, path_suffix=path_suffix, path_ignore_suffix=path_ignore_suffix, + con=con, cursor=cursor, table=table, schema=schema, mode=mode, + overwrite_method=overwrite_method, diststyle=diststyle, sortstyle=sortstyle, distkey=distkey, @@ -1260,6 +1324,7 @@ def copy( # pylint: disable=too-many-arguments index: bool = False, dtype: Optional[Dict[str, str]] = None, mode: str = "append", + overwrite_method: str = "drop", diststyle: str = "AUTO", distkey: Optional[str] = None, sortstyle: str = "COMPOUND", @@ -1327,9 +1392,17 @@ def copy( # pylint: disable=too-many-arguments Useful when you have columns with undetermined or mixed data types. Only takes effect if dataset=True. (e.g. {'col name': 'bigint', 'col2 name': 'int'}) - mode : str + mode: str Append, overwrite or upsert. - diststyle : str + overwrite_method : str + Drop, cascade, truncate, or delete. Only applicable in overwrite mode. + + "drop" - ``DROP ... RESTRICT`` - drops the table. Fails if there are any views that depend on it. + "cascade" - ``DROP ... CASCADE`` - drops the table, and all views that depend on it. + "truncate" - ``TRUNCATE ...`` - truncates the table, but immediatly commits current + transaction & starts a new one, hence the overwrite happens in two transactions and is not atomic. + "delete" - ``DELETE FROM ...`` - deletes all rows from the table. Slow relative to the other methods. + diststyle: str Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"]. https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html distkey : str, optional @@ -1416,6 +1489,7 @@ def copy( # pylint: disable=too-many-arguments aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, mode=mode, + overwrite_method=overwrite_method, diststyle=diststyle, distkey=distkey, sortstyle=sortstyle, diff --git a/tests/test_redshift.py b/tests/test_redshift.py index 180ef521a..c9f1fb88e 100644 --- a/tests/test_redshift.py +++ b/tests/test_redshift.py @@ -42,9 +42,10 @@ def test_read_sql_query_simple(databases_parameters): assert df.shape == (1, 1) -def test_to_sql_simple(redshift_table, redshift_con): +@pytest.mark.parametrize("overwrite_method", [None, "drop", "cascade", "truncate", "delete"]) +def test_to_sql_simple(redshift_table, redshift_con, overwrite_method): df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]}) - wr.redshift.to_sql(df, redshift_con, redshift_table, "public", "overwrite", True) + wr.redshift.to_sql(df, redshift_con, redshift_table, "public", "overwrite", overwrite_method, True) def test_sql_types(redshift_table, redshift_con):