From 2751457af89638e198e3ab97e226f6d53279a8dc Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Tue, 25 May 2021 18:10:53 +0100 Subject: [PATCH 1/2] Adding flag to skip Redshift transaction commit --- awswrangler/redshift.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/awswrangler/redshift.py b/awswrangler/redshift.py index 0d3a08f34..17ef3f1b0 100644 --- a/awswrangler/redshift.py +++ b/awswrangler/redshift.py @@ -685,7 +685,7 @@ def read_sql_table( @apply_configs -def to_sql( +def to_sql( # pylint: disable=too-many-locals df: pd.DataFrame, con: redshift_connector.Connection, table: str, @@ -704,6 +704,7 @@ def to_sql( use_column_names: bool = False, lock: bool = False, chunksize: int = 200, + commit_transaction: bool = True, ) -> None: """Write records stored in a DataFrame into Redshift. @@ -764,6 +765,8 @@ def to_sql( True to execute LOCK command inside the transaction to force serializable isolation. chunksize: int Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query. + commit_transaction: bool + Whether to commit the transaction. True by default. Returns ------- @@ -829,7 +832,8 @@ def to_sql( if lock: _lock(cursor, [table], schema=schema) _upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys) - con.commit() + if commit_transaction: + con.commit() except Exception as ex: con.rollback() _logger.error(ex) @@ -1139,6 +1143,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments path_ignore_suffix: Optional[str] = None, use_threads: bool = True, lock: bool = False, + commit_transaction: bool = True, boto3_session: Optional[boto3.Session] = None, s3_additional_kwargs: Optional[Dict[str, str]] = None, ) -> None: @@ -1227,6 +1232,8 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments If enabled os.cpu_count() will be used as the max number of threads. lock : bool True to execute LOCK command inside the transaction to force serializable isolation. + commit_transaction: bool + Whether to commit the transaction. True by default. boto3_session : boto3.Session(), optional Boto3 Session. The default boto3 session will be used if boto3_session receive None. s3_additional_kwargs: @@ -1302,7 +1309,8 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments if lock: _lock(cursor, [table], schema=schema) _upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys) - con.commit() + if commit_transaction: + con.commit() except Exception as ex: con.rollback() _logger.error(ex) From 0d7c76df6811103e92b332945a5fa167c912c910 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Thu, 27 May 2021 23:12:45 +0100 Subject: [PATCH 2/2] Adding tests --- awswrangler/cloudwatch.py | 6 +++--- awswrangler/s3/_write_dataset.py | 2 +- tests/test_redshift.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/awswrangler/cloudwatch.py b/awswrangler/cloudwatch.py index 03f632087..8c92856cd 100644 --- a/awswrangler/cloudwatch.py +++ b/awswrangler/cloudwatch.py @@ -29,7 +29,7 @@ def start_query( query: str, log_group_names: List[str], start_time: datetime.datetime = datetime.datetime(year=1970, month=1, day=1, tzinfo=datetime.timezone.utc), - end_time: datetime.datetime = datetime.datetime.now(), + end_time: datetime.datetime = datetime.datetime.utcnow(), limit: Optional[int] = None, boto3_session: Optional[boto3.Session] = None, ) -> str: @@ -132,7 +132,7 @@ def run_query( query: str, log_group_names: List[str], start_time: datetime.datetime = datetime.datetime(year=1970, month=1, day=1, tzinfo=datetime.timezone.utc), - end_time: datetime.datetime = datetime.datetime.now(), + end_time: datetime.datetime = datetime.datetime.utcnow(), limit: Optional[int] = None, boto3_session: Optional[boto3.Session] = None, ) -> List[List[Dict[str, str]]]: @@ -186,7 +186,7 @@ def read_logs( query: str, log_group_names: List[str], start_time: datetime.datetime = datetime.datetime(year=1970, month=1, day=1, tzinfo=datetime.timezone.utc), - end_time: datetime.datetime = datetime.datetime.now(), + end_time: datetime.datetime = datetime.datetime.utcnow(), limit: Optional[int] = None, boto3_session: Optional[boto3.Session] = None, ) -> pd.DataFrame: diff --git a/awswrangler/s3/_write_dataset.py b/awswrangler/s3/_write_dataset.py index 977ff152c..1c37bcdae 100644 --- a/awswrangler/s3/_write_dataset.py +++ b/awswrangler/s3/_write_dataset.py @@ -81,7 +81,7 @@ def _to_buckets( **func_kwargs: Any, ) -> List[str]: _proxy: _WriteProxy = proxy if proxy else _WriteProxy(use_threads=False) - bucket_number_series = df.apply( + bucket_number_series = df.astype("O").apply( lambda row: _get_bucket_number(bucketing_info[1], [row[col_name] for col_name in bucketing_info[0]]), axis="columns", ) diff --git a/tests/test_redshift.py b/tests/test_redshift.py index c9f1fb88e..845f20fbe 100644 --- a/tests/test_redshift.py +++ b/tests/test_redshift.py @@ -914,3 +914,34 @@ def test_dfs_are_equal_for_different_chunksizes(redshift_table, redshift_con, ch df["c1"] = df["c1"].astype("string") assert df.equals(df2) + + +def test_to_sql_multi_transaction(redshift_table, redshift_con): + df = pd.DataFrame({"id": list((range(10))), "val": list(["foo" if i % 2 == 0 else "boo" for i in range(10)])}) + df2 = pd.DataFrame({"id": list((range(10, 15))), "val": list(["foo" if i % 2 == 0 else "boo" for i in range(5)])}) + + wr.redshift.to_sql( + df=df, + con=redshift_con, + schema="public", + table=redshift_table, + mode="overwrite", + index=False, + primary_keys=["id"], + commit_transaction=False, # Not committing + ) + + wr.redshift.to_sql( + df=df2, + con=redshift_con, + schema="public", + table=redshift_table, + mode="upsert", + index=False, + primary_keys=["id"], + commit_transaction=False, # Not committing + ) + redshift_con.commit() + df3 = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table} ORDER BY id", con=redshift_con) + assert len(df.index) + len(df2.index) == len(df3.index) + assert len(df.columns) == len(df3.columns)