From 033b9728ee304ea0f639240f55677e92d3829ea3 Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Mon, 12 Jul 2021 15:21:46 +0200 Subject: [PATCH 1/2] Add postgres upsert --- awswrangler/_databases.py | 6 ++ awswrangler/_utils.py | 2 +- awswrangler/mysql.py | 11 ++- awswrangler/postgresql.py | 23 ++++++- tests/test_postgresql.py | 141 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 174 insertions(+), 9 deletions(-) diff --git a/awswrangler/_databases.py b/awswrangler/_databases.py index 8ae459744..fb7904c78 100644 --- a/awswrangler/_databases.py +++ b/awswrangler/_databases.py @@ -256,3 +256,9 @@ def convert_value_to_native_python_type(value: Any) -> Any: chunk_placeholders = ", ".join([f"({column_placeholders})" for _ in range(len(parameters_chunk))]) flattened_chunk = [convert_value_to_native_python_type(value) for row in parameters_chunk for value in row] yield chunk_placeholders, flattened_chunk + + +def validate_mode(mode: str, allowed_modes: List[str]) -> None: + """Check if mode is included in allowed_modes.""" + if mode not in allowed_modes: + raise exceptions.InvalidArgumentValue(f"mode must be one of {', '.join(allowed_modes)}") diff --git a/awswrangler/_utils.py b/awswrangler/_utils.py index 9b2423c5b..94ab86d3b 100644 --- a/awswrangler/_utils.py +++ b/awswrangler/_utils.py @@ -227,7 +227,7 @@ def chunkify(lst: List[Any], num_chunks: int = 1, max_length: Optional[int] = No if not lst: return [] n: int = num_chunks if max_length is None else int(math.ceil((float(len(lst)) / float(max_length)))) - np_chunks = np.array_split(lst, n) # type: ignore + np_chunks = np.array_split(lst, n) return [arr.tolist() for arr in np_chunks if len(arr) > 0] diff --git a/awswrangler/mysql.py b/awswrangler/mysql.py index 573fe95fa..289375cb4 100644 --- a/awswrangler/mysql.py +++ b/awswrangler/mysql.py @@ -292,8 +292,8 @@ def to_sql( Schema name mode : str Append, overwrite, upsert_duplicate_key, upsert_replace_into, upsert_distinct. - append: Inserts new records into table - overwrite: Drops table and recreates + append: Inserts new records into table. + overwrite: Drops table and recreates. upsert_duplicate_key: Performs an upsert using `ON DUPLICATE KEY` clause. Requires table schema to have defined keys, otherwise duplicate records will be inserted. upsert_replace_into: Performs upsert using `REPLACE INTO` clause. Less efficient and still requires the @@ -340,17 +340,16 @@ def to_sql( """ if df.empty is True: raise exceptions.EmptyDataFrame() + mode = mode.strip().lower() - modes = [ + allowed_modes = [ "append", "overwrite", "upsert_replace_into", "upsert_duplicate_key", "upsert_distinct", ] - if mode not in modes: - raise exceptions.InvalidArgumentValue(f"mode must be one of {', '.join(modes)}") - + _db_utils.validate_mode(mode=mode, allowed_modes=allowed_modes) _validate_connection(con=con) try: with con.cursor() as cursor: diff --git a/awswrangler/postgresql.py b/awswrangler/postgresql.py index 151eca61a..bc51ece20 100644 --- a/awswrangler/postgresql.py +++ b/awswrangler/postgresql.py @@ -277,6 +277,7 @@ def to_sql( varchar_lengths: Optional[Dict[str, int]] = None, use_column_names: bool = False, chunksize: int = 200, + upsert_conflict_columns: Optional[List[str]] = None, ) -> None: """Write records stored in a DataFrame into PostgreSQL. @@ -291,7 +292,11 @@ def to_sql( schema : str Schema name mode : str - Append or overwrite. + Append, overwrite or upsert. + append: Inserts new records into table. + overwrite: Drops table and recreates. + upsert: Perform an upsert which checks for conflicts on columns given by `upsert_conflict_columns` and + sets the new values on conflicts. Note that `upsert_conflict_columns` is required for this mode. index : bool True to store the DataFrame index as a column in the table, otherwise False to ignore it. @@ -307,6 +312,9 @@ def to_sql( inserted into the database columns `col1` and `col3`. chunksize: int Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query. + upsert_conflict_columns: List[str], optional + This parameter is only supported if `mode` is set top `upsert`. In this case conflicts for the given columns are + checked for evaluating the upsert. Returns ------- @@ -330,6 +338,12 @@ def to_sql( """ if df.empty is True: raise exceptions.EmptyDataFrame() + + mode = mode.strip().lower() + allowed_modes = ["append", "overwrite", "upsert"] + _db_utils.validate_mode(mode=mode, allowed_modes=allowed_modes) + if mode == "upsert" and not upsert_conflict_columns: + raise exceptions.InvalidArgumentValue(" needs to be set when using upsert mode.") _validate_connection(con=con) try: with con.cursor() as cursor: @@ -347,13 +361,18 @@ def to_sql( df.reset_index(level=df.index.names, inplace=True) column_placeholders: str = ", ".join(["%s"] * len(df.columns)) insertion_columns = "" + upsert_str = "" if use_column_names: insertion_columns = f"({', '.join(df.columns)})" + if mode == "upsert": + upsert_columns = ", ".join(df.columns.map(lambda column: f"{column}=EXCLUDED.{column}")) + conflict_columns = ", ".join(upsert_conflict_columns) # type: ignore + upsert_str = f" ON CONFLICT ({conflict_columns}) DO UPDATE SET {upsert_columns}" placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs( df=df, column_placeholders=column_placeholders, chunksize=chunksize ) for placeholders, parameters in placeholder_parameter_pair_generator: - sql: str = f'INSERT INTO "{schema}"."{table}" {insertion_columns} VALUES {placeholders}' + sql: str = f'INSERT INTO "{schema}"."{table}" {insertion_columns} VALUES {placeholders}{upsert_str}' _logger.debug("sql: %s", sql) cursor.executemany(sql, (parameters,)) con.commit() diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 4ee0dcb89..1ac10ca81 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -219,3 +219,144 @@ def test_dfs_are_equal_for_different_chunksizes(postgresql_table, postgresql_con df["c1"] = df["c1"].astype("string") assert df.equals(df2) + + +def test_upsert(postgresql_table, postgresql_con): + create_table_sql = ( + f"CREATE TABLE public.{postgresql_table} " + "(c0 varchar NULL PRIMARY KEY," + "c1 int NULL DEFAULT 42," + "c2 int NOT NULL);" + ) + with postgresql_con.cursor() as cursor: + cursor.execute(create_table_sql) + postgresql_con.commit() + + df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]}) + + wr.postgresql.to_sql( + df=df, + con=postgresql_con, + schema="public", + table=postgresql_table, + mode="upsert", + upsert_conflict_columns=["c0"], + use_column_names=True, + ) + wr.postgresql.to_sql( + df=df, + con=postgresql_con, + schema="public", + table=postgresql_table, + mode="upsert", + upsert_conflict_columns=["c0"], + use_column_names=True, + ) + df2 = wr.postgresql.read_sql_table(con=postgresql_con, schema="public", table=postgresql_table) + assert bool(len(df2) == 2) + + wr.postgresql.to_sql( + df=df, + con=postgresql_con, + schema="public", + table=postgresql_table, + mode="upsert", + upsert_conflict_columns=["c0"], + use_column_names=True, + ) + df3 = pd.DataFrame({"c0": ["baz", "bar"], "c2": [3, 2]}) + wr.postgresql.to_sql( + df=df3, + con=postgresql_con, + schema="public", + table=postgresql_table, + mode="upsert", + upsert_conflict_columns=["c0"], + use_column_names=True, + ) + df4 = wr.postgresql.read_sql_table(con=postgresql_con, schema="public", table=postgresql_table) + assert bool(len(df4) == 3) + + df5 = pd.DataFrame({"c0": ["foo", "bar"], "c2": [4, 5]}) + wr.postgresql.to_sql( + df=df5, + con=postgresql_con, + schema="public", + table=postgresql_table, + mode="upsert", + upsert_conflict_columns=["c0"], + use_column_names=True, + ) + + df6 = wr.postgresql.read_sql_table(con=postgresql_con, schema="public", table=postgresql_table) + assert bool(len(df6) == 3) + assert bool(len(df6.loc[(df6["c0"] == "foo") & (df6["c2"] == 4)]) == 1) + assert bool(len(df6.loc[(df6["c0"] == "bar") & (df6["c2"] == 5)]) == 1) + + +def test_upsert_multiple_conflict_columns(postgresql_table, postgresql_con): + create_table_sql = ( + f"CREATE TABLE public.{postgresql_table} " + "(c0 varchar NULL PRIMARY KEY," + "c1 int NOT NULL," + "c2 int NOT NULL," + "UNIQUE (c1, c2));" + ) + with postgresql_con.cursor() as cursor: + cursor.execute(create_table_sql) + postgresql_con.commit() + + df = pd.DataFrame({"c0": ["foo", "bar"], "c1": [1, 2], "c2": [3, 4]}) + upsert_conflict_columns = ["c1", "c2"] + + wr.postgresql.to_sql( + df=df, + con=postgresql_con, + schema="public", + table=postgresql_table, + mode="upsert", + upsert_conflict_columns=upsert_conflict_columns, + use_column_names=True, + ) + wr.postgresql.to_sql( + df=df, + con=postgresql_con, + schema="public", + table=postgresql_table, + mode="upsert", + upsert_conflict_columns=upsert_conflict_columns, + use_column_names=True, + ) + df2 = wr.postgresql.read_sql_table(con=postgresql_con, schema="public", table=postgresql_table) + assert bool(len(df2) == 2) + + df3 = pd.DataFrame({"c0": ["baz", "spam"], "c1": [1, 5], "c2": [3, 2]}) + wr.postgresql.to_sql( + df=df3, + con=postgresql_con, + schema="public", + table=postgresql_table, + mode="upsert", + upsert_conflict_columns=upsert_conflict_columns, + use_column_names=True, + ) + df4 = wr.postgresql.read_sql_table(con=postgresql_con, schema="public", table=postgresql_table) + assert bool(len(df4) == 3) + + df5 = pd.DataFrame({"c0": ["egg", "spam"], "c1": [2, 5], "c2": [4, 2]}) + wr.postgresql.to_sql( + df=df5, + con=postgresql_con, + schema="public", + table=postgresql_table, + mode="upsert", + upsert_conflict_columns=upsert_conflict_columns, + use_column_names=True, + ) + + df6 = wr.postgresql.read_sql_table(con=postgresql_con, schema="public", table=postgresql_table) + df7 = pd.DataFrame({"c0": ["baz", "egg", "spam"], "c1": [1, 2, 5], "c2": [3, 4, 2]}) + df7["c0"] = df7["c0"].astype("string") + df7["c1"] = df7["c1"].astype("Int64") + df7["c2"] = df7["c2"].astype("Int64") + assert df6.equals(df7) From 475710c51de1defe8c7841205268bd3a67a348b2 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Fri, 16 Jul 2021 10:54:22 +0100 Subject: [PATCH 2/2] Minor - Raising exception when no conflict cols --- tests/test_postgresql.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 1ac10ca81..6478cc09f 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -234,6 +234,17 @@ def test_upsert(postgresql_table, postgresql_con): df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]}) + with pytest.raises(wr.exceptions.InvalidArgumentValue): + wr.postgresql.to_sql( + df=df, + con=postgresql_con, + schema="public", + table=postgresql_table, + mode="upsert", + upsert_conflict_columns=None, + use_column_names=True, + ) + wr.postgresql.to_sql( df=df, con=postgresql_con,