14 changes: 11 additions & 3 deletions awswrangler/redshift.py
@@ -685,7 +685,7 @@ def read_sql_table(
 
 
 @apply_configs
-def to_sql(
+def to_sql(  # pylint: disable=too-many-locals
     df: pd.DataFrame,
     con: redshift_connector.Connection,
     table: str,
@@ -704,6 +704,7 @@ def to_sql(
     use_column_names: bool = False,
     lock: bool = False,
     chunksize: int = 200,
+    commit_transaction: bool = True,
 ) -> None:
     """Write records stored in a DataFrame into Redshift.
 
@@ -764,6 +765,8 @@ def to_sql(
         True to execute LOCK command inside the transaction to force serializable isolation.
     chunksize: int
         Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
+    commit_transaction: bool
+        Whether to commit the transaction. True by default.
 
     Returns
     -------
@@ -829,7 +832,8 @@ def to_sql(
                 if lock:
                     _lock(cursor, [table], schema=schema)
                 _upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
-            con.commit()
+            if commit_transaction:
+                con.commit()
     except Exception as ex:
         con.rollback()
         _logger.error(ex)
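Usage note (an illustration, not part of this diff): with `commit_transaction=False` the caller owns the transaction boundary and must call `commit()` or `rollback()` on the connection. A minimal sketch of the intended pattern, assuming an open `redshift_connector.Connection` named `con`, two pre-built DataFrames `df_a` and `df_b`, and a placeholder table name:

```python
import awswrangler as wr

try:
    # Stage two writes inside one transaction; nothing is visible yet.
    wr.redshift.to_sql(
        df=df_a, con=con, schema="public", table="my_table",
        mode="overwrite", commit_transaction=False,
    )
    wr.redshift.to_sql(
        df=df_b, con=con, schema="public", table="my_table",
        mode="append", commit_transaction=False,
    )
    con.commit()  # both writes become visible atomically
except Exception:
    con.rollback()  # discard any uncommitted work
    raise
```

Note that `to_sql` still rolls back on its own internal errors (see the `except` branch in the hunk above), so the explicit `rollback()` here mainly covers failures in the caller's code between the two writes.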
@@ -1139,6 +1143,7 @@ def copy_from_files(  # pylint: disable=too-many-locals,too-many-arguments
     path_ignore_suffix: Optional[str] = None,
     use_threads: bool = True,
     lock: bool = False,
+    commit_transaction: bool = True,
     boto3_session: Optional[boto3.Session] = None,
     s3_additional_kwargs: Optional[Dict[str, str]] = None,
 ) -> None:
@@ -1227,6 +1232,8 @@ def copy_from_files(  # pylint: disable=too-many-locals,too-many-arguments
         If enabled os.cpu_count() will be used as the max number of threads.
     lock : bool
         True to execute LOCK command inside the transaction to force serializable isolation.
+    commit_transaction: bool
+        Whether to commit the transaction. True by default.
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receives None.
     s3_additional_kwargs:
@@ -1302,7 +1309,8 @@ def copy_from_files(  # pylint: disable=too-many-locals,too-many-arguments
                 if lock:
                     _lock(cursor, [table], schema=schema)
                 _upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
-            con.commit()
+            if commit_transaction:
+                con.commit()
     except Exception as ex:
         con.rollback()
         _logger.error(ex)
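The same pattern applies to `copy_from_files`: a sketch (again not part of the diff) of batching two COPY loads into a single commit. The S3 prefixes, IAM role ARN, and table name below are placeholders, and `con` is assumed to be an open connection:

```python
import awswrangler as wr

for path in ("s3://my-bucket/stage/part1/", "s3://my-bucket/stage/part2/"):
    wr.redshift.copy_from_files(
        path=path,
        con=con,
        table="my_table",
        schema="public",
        iam_role="arn:aws:iam::111111111111:role/RedshiftCopyRole",
        mode="append",
        commit_transaction=False,  # hold the transaction open
    )
con.commit()  # expose both loads at once
```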
31 changes: 31 additions & 0 deletions tests/test_redshift.py
@@ -914,3 +914,34 @@ def test_dfs_are_equal_for_different_chunksizes(redshift_table, redshift_con, ch
     df["c1"] = df["c1"].astype("string")
 
     assert df.equals(df2)
+
+
+def test_to_sql_multi_transaction(redshift_table, redshift_con):
+    df = pd.DataFrame({"id": list(range(10)), "val": ["foo" if i % 2 == 0 else "boo" for i in range(10)]})
+    df2 = pd.DataFrame({"id": list(range(10, 15)), "val": ["foo" if i % 2 == 0 else "boo" for i in range(5)]})
+
+    wr.redshift.to_sql(
+        df=df,
+        con=redshift_con,
+        schema="public",
+        table=redshift_table,
+        mode="overwrite",
+        index=False,
+        primary_keys=["id"],
+        commit_transaction=False,  # Not committing
+    )
+
+    wr.redshift.to_sql(
+        df=df2,
+        con=redshift_con,
+        schema="public",
+        table=redshift_table,
+        mode="upsert",
+        index=False,
+        primary_keys=["id"],
+        commit_transaction=False,  # Not committing
+    )
+    redshift_con.commit()
+    df3 = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table} ORDER BY id", con=redshift_con)
+    assert len(df.index) + len(df2.index) == len(df3.index)
+    assert len(df.columns) == len(df3.columns)
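A possible companion test for the rollback path (a sketch only, not included in this PR; it assumes the same pytest fixtures as the tests above and that Redshift cleanly rolls back the uncommitted overwrite):

```python
def test_to_sql_multi_transaction_rollback(redshift_table, redshift_con):
    df = pd.DataFrame({"id": [0], "val": ["foo"]})
    df2 = pd.DataFrame({"id": [1], "val": ["boo"]})

    wr.redshift.to_sql(
        df=df,
        con=redshift_con,
        schema="public",
        table=redshift_table,
        mode="overwrite",
        index=False,
        commit_transaction=False,  # leave the transaction open
    )
    redshift_con.rollback()  # discard the uncommitted write

    # commit_transaction defaults to True, so this call commits on success
    wr.redshift.to_sql(
        df=df2,
        con=redshift_con,
        schema="public",
        table=redshift_table,
        mode="overwrite",
        index=False,
    )
    df3 = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table}", con=redshift_con)
    assert len(df3.index) == len(df2.index)  # only the committed write survives
```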