diff --git a/awswrangler/redshift.py b/awswrangler/redshift.py
index 0d3a08f34..17ef3f1b0 100644
--- a/awswrangler/redshift.py
+++ b/awswrangler/redshift.py
@@ -685,7 +685,7 @@ def read_sql_table(
 
 
 @apply_configs
-def to_sql(
+def to_sql(  # pylint: disable=too-many-locals
     df: pd.DataFrame,
     con: redshift_connector.Connection,
     table: str,
@@ -704,6 +704,7 @@ def to_sql(
     use_column_names: bool = False,
     lock: bool = False,
     chunksize: int = 200,
+    commit_transaction: bool = True,
 ) -> None:
     """Write records stored in a DataFrame into Redshift.
 
@@ -764,6 +765,8 @@ def to_sql(
         True to execute LOCK command inside the transaction to force serializable isolation.
     chunksize: int
         Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
+    commit_transaction: bool
+        Whether to commit the transaction. True by default.
 
     Returns
     -------
@@ -829,7 +832,8 @@ def to_sql(
                 if lock:
                     _lock(cursor, [table], schema=schema)
                 _upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
-        con.commit()
+        if commit_transaction:
+            con.commit()
     except Exception as ex:
         con.rollback()
         _logger.error(ex)
@@ -1139,6 +1143,7 @@ def copy_from_files(  # pylint: disable=too-many-locals,too-many-arguments
     path_ignore_suffix: Optional[str] = None,
     use_threads: bool = True,
     lock: bool = False,
+    commit_transaction: bool = True,
     boto3_session: Optional[boto3.Session] = None,
     s3_additional_kwargs: Optional[Dict[str, str]] = None,
 ) -> None:
@@ -1227,6 +1232,8 @@ def copy_from_files(  # pylint: disable=too-many-locals,too-many-arguments
         If enabled os.cpu_count() will be used as the max number of threads.
     lock : bool
         True to execute LOCK command inside the transaction to force serializable isolation.
+    commit_transaction: bool
+        Whether to commit the transaction. True by default.
     boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.
     s3_additional_kwargs:
@@ -1302,7 +1309,8 @@ def copy_from_files(  # pylint: disable=too-many-locals,too-many-arguments
                 if lock:
                     _lock(cursor, [table], schema=schema)
                 _upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
-        con.commit()
+        if commit_transaction:
+            con.commit()
     except Exception as ex:
         con.rollback()
         _logger.error(ex)
diff --git a/tests/test_redshift.py b/tests/test_redshift.py
index c9f1fb88e..845f20fbe 100644
--- a/tests/test_redshift.py
+++ b/tests/test_redshift.py
@@ -914,3 +914,34 @@ def test_dfs_are_equal_for_different_chunksizes(redshift_table, redshift_con, ch
     df["c1"] = df["c1"].astype("string")
 
     assert df.equals(df2)
+
+
+def test_to_sql_multi_transaction(redshift_table, redshift_con):
+    df = pd.DataFrame({"id": list(range(10)), "val": ["foo" if i % 2 == 0 else "boo" for i in range(10)]})
+    df2 = pd.DataFrame({"id": list(range(10, 15)), "val": ["foo" if i % 2 == 0 else "boo" for i in range(5)]})
+
+    wr.redshift.to_sql(
+        df=df,
+        con=redshift_con,
+        schema="public",
+        table=redshift_table,
+        mode="overwrite",
+        index=False,
+        primary_keys=["id"],
+        commit_transaction=False,  # Not committing
+    )
+
+    wr.redshift.to_sql(
+        df=df2,
+        con=redshift_con,
+        schema="public",
+        table=redshift_table,
+        mode="upsert",
+        index=False,
+        primary_keys=["id"],
+        commit_transaction=False,  # Not committing
+    )
+    redshift_con.commit()
+    df3 = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table} ORDER BY id", con=redshift_con)
+    assert len(df.index) + len(df2.index) == len(df3.index)
+    assert len(df.columns) == len(df3.columns)
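
Usage sketch (not part of the patch): how the new `commit_transaction` flag lets a caller batch several writes into one atomic transaction, mirroring the test above. The connection name and table name below are placeholders; `wr.redshift.connect` resolves a Glue Catalog connection, and any open `redshift_connector.Connection` behaves the same way.

```python
import pandas as pd

import awswrangler as wr

# "my-glue-connection" is a placeholder Glue Catalog connection name.
con = wr.redshift.connect("my-glue-connection")

df_initial = pd.DataFrame({"id": [1, 2], "val": ["foo", "boo"]})
df_update = pd.DataFrame({"id": [2, 3], "val": ["bar", "baz"]})

# Both writes share the connection's open transaction; neither is
# visible to other sessions until the explicit commit below.
wr.redshift.to_sql(
    df=df_initial,
    con=con,
    schema="public",
    table="my_table",  # placeholder table name
    mode="overwrite",
    index=False,
    primary_keys=["id"],
    commit_transaction=False,
)
wr.redshift.to_sql(
    df=df_update,
    con=con,
    schema="public",
    table="my_table",
    mode="upsert",
    index=False,
    primary_keys=["id"],
    commit_transaction=False,
)
con.commit()  # both statements become visible atomically
con.close()
```

Note that the error path is unchanged by this patch: `to_sql` and `copy_from_files` still call `con.rollback()` on failure, so the flag only defers the commit; it does not leave a failed transaction open.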