Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions awswrangler/_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,9 @@ def convert_value_to_native_python_type(value: Any) -> Any:
chunk_placeholders = ", ".join([f"({column_placeholders})" for _ in range(len(parameters_chunk))])
flattened_chunk = [convert_value_to_native_python_type(value) for row in parameters_chunk for value in row]
yield chunk_placeholders, flattened_chunk


def validate_mode(mode: str, allowed_modes: List[str]) -> None:
    """Validate that ``mode`` is one of ``allowed_modes``.

    Raises
    ------
    exceptions.InvalidArgumentValue
        If ``mode`` is not contained in ``allowed_modes``.
    """
    if mode in allowed_modes:
        return
    raise exceptions.InvalidArgumentValue(f"mode must be one of {', '.join(allowed_modes)}")
2 changes: 1 addition & 1 deletion awswrangler/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def chunkify(lst: List[Any], num_chunks: int = 1, max_length: Optional[int] = No
if not lst:
return []
n: int = num_chunks if max_length is None else int(math.ceil((float(len(lst)) / float(max_length))))
np_chunks = np.array_split(lst, n) # type: ignore
np_chunks = np.array_split(lst, n)
return [arr.tolist() for arr in np_chunks if len(arr) > 0]


Expand Down
11 changes: 5 additions & 6 deletions awswrangler/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ def to_sql(
Schema name
mode : str
Append, overwrite, upsert_duplicate_key, upsert_replace_into, upsert_distinct.
append: Inserts new records into table
overwrite: Drops table and recreates
append: Inserts new records into table.
overwrite: Drops table and recreates.
upsert_duplicate_key: Performs an upsert using `ON DUPLICATE KEY` clause. Requires table schema to have
defined keys, otherwise duplicate records will be inserted.
upsert_replace_into: Performs upsert using `REPLACE INTO` clause. Less efficient and still requires the
Expand Down Expand Up @@ -340,17 +340,16 @@ def to_sql(
"""
if df.empty is True:
raise exceptions.EmptyDataFrame()

mode = mode.strip().lower()
modes = [
allowed_modes = [
"append",
"overwrite",
"upsert_replace_into",
"upsert_duplicate_key",
"upsert_distinct",
]
if mode not in modes:
raise exceptions.InvalidArgumentValue(f"mode must be one of {', '.join(modes)}")

_db_utils.validate_mode(mode=mode, allowed_modes=allowed_modes)
_validate_connection(con=con)
try:
with con.cursor() as cursor:
Expand Down
23 changes: 21 additions & 2 deletions awswrangler/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def to_sql(
varchar_lengths: Optional[Dict[str, int]] = None,
use_column_names: bool = False,
chunksize: int = 200,
upsert_conflict_columns: Optional[List[str]] = None,
) -> None:
"""Write records stored in a DataFrame into PostgreSQL.

Expand All @@ -291,7 +292,11 @@ def to_sql(
schema : str
Schema name
mode : str
Append or overwrite.
Append, overwrite or upsert.
append: Inserts new records into table.
overwrite: Drops table and recreates.
upsert: Performs an upsert which checks for conflicts on columns given by `upsert_conflict_columns` and
sets the new values on conflicts. Note that `upsert_conflict_columns` is required for this mode.
index : bool
True to store the DataFrame index as a column in the table,
otherwise False to ignore it.
Expand All @@ -307,6 +312,9 @@ def to_sql(
inserted into the database columns `col1` and `col3`.
chunksize: int
Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
upsert_conflict_columns: List[str], optional
This parameter is only supported if `mode` is set to `upsert`. In this case conflicts for the given columns are
checked for evaluating the upsert.

Returns
-------
Expand All @@ -330,6 +338,12 @@ def to_sql(
"""
if df.empty is True:
raise exceptions.EmptyDataFrame()

mode = mode.strip().lower()
allowed_modes = ["append", "overwrite", "upsert"]
_db_utils.validate_mode(mode=mode, allowed_modes=allowed_modes)
if mode == "upsert" and not upsert_conflict_columns:
raise exceptions.InvalidArgumentValue("<upsert_conflict_columns> needs to be set when using upsert mode.")
_validate_connection(con=con)
try:
with con.cursor() as cursor:
Expand All @@ -347,13 +361,18 @@ def to_sql(
df.reset_index(level=df.index.names, inplace=True)
column_placeholders: str = ", ".join(["%s"] * len(df.columns))
insertion_columns = ""
upsert_str = ""
if use_column_names:
insertion_columns = f"({', '.join(df.columns)})"
if mode == "upsert":
upsert_columns = ", ".join(df.columns.map(lambda column: f"{column}=EXCLUDED.{column}"))
conflict_columns = ", ".join(upsert_conflict_columns) # type: ignore
upsert_str = f" ON CONFLICT ({conflict_columns}) DO UPDATE SET {upsert_columns}"
placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
df=df, column_placeholders=column_placeholders, chunksize=chunksize
)
for placeholders, parameters in placeholder_parameter_pair_generator:
sql: str = f'INSERT INTO "{schema}"."{table}" {insertion_columns} VALUES {placeholders}'
sql: str = f'INSERT INTO "{schema}"."{table}" {insertion_columns} VALUES {placeholders}{upsert_str}'
_logger.debug("sql: %s", sql)
cursor.executemany(sql, (parameters,))
con.commit()
Expand Down
152 changes: 152 additions & 0 deletions tests/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,155 @@ def test_dfs_are_equal_for_different_chunksizes(postgresql_table, postgresql_con
df["c1"] = df["c1"].astype("string")

assert df.equals(df2)


def test_upsert(postgresql_table, postgresql_con):
    # Table with a single-column primary key so ON CONFLICT (c0) can fire.
    create_table_sql = (
        f"CREATE TABLE public.{postgresql_table} "
        "(c0 varchar NULL PRIMARY KEY,"
        "c1 int NULL DEFAULT 42,"
        "c2 int NOT NULL);"
    )
    with postgresql_con.cursor() as cursor:
        cursor.execute(create_table_sql)
    postgresql_con.commit()

    def upsert(frame, conflict_columns):
        # Single place for the repeated to_sql invocation.
        wr.postgresql.to_sql(
            df=frame,
            con=postgresql_con,
            schema="public",
            table=postgresql_table,
            mode="upsert",
            upsert_conflict_columns=conflict_columns,
            use_column_names=True,
        )

    def read_back():
        return wr.postgresql.read_sql_table(con=postgresql_con, schema="public", table=postgresql_table)

    df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})

    # upsert mode without conflict columns must be rejected.
    with pytest.raises(wr.exceptions.InvalidArgumentValue):
        upsert(df, None)

    # Writing the same rows twice must not create duplicates.
    upsert(df, ["c0"])
    upsert(df, ["c0"])
    df2 = read_back()
    assert bool(len(df2) == 2)

    # A mix of one new key and one existing key grows the table by one.
    upsert(df, ["c0"])
    df3 = pd.DataFrame({"c0": ["baz", "bar"], "c2": [3, 2]})
    upsert(df3, ["c0"])
    df4 = read_back()
    assert bool(len(df4) == 3)

    # Conflicting keys update c2 in place instead of inserting.
    df5 = pd.DataFrame({"c0": ["foo", "bar"], "c2": [4, 5]})
    upsert(df5, ["c0"])

    df6 = read_back()
    assert bool(len(df6) == 3)
    assert bool(len(df6.loc[(df6["c0"] == "foo") & (df6["c2"] == 4)]) == 1)
    assert bool(len(df6.loc[(df6["c0"] == "bar") & (df6["c2"] == 5)]) == 1)


def test_upsert_multiple_conflict_columns(postgresql_table, postgresql_con):
    # Table with a composite UNIQUE constraint so ON CONFLICT (c1, c2) can fire.
    create_table_sql = (
        f"CREATE TABLE public.{postgresql_table} "
        "(c0 varchar NULL PRIMARY KEY,"
        "c1 int NOT NULL,"
        "c2 int NOT NULL,"
        "UNIQUE (c1, c2));"
    )
    with postgresql_con.cursor() as cursor:
        cursor.execute(create_table_sql)
    postgresql_con.commit()

    conflict_columns = ["c1", "c2"]

    def upsert(frame):
        # Single place for the repeated to_sql invocation.
        wr.postgresql.to_sql(
            df=frame,
            con=postgresql_con,
            schema="public",
            table=postgresql_table,
            mode="upsert",
            upsert_conflict_columns=conflict_columns,
            use_column_names=True,
        )

    def read_back():
        return wr.postgresql.read_sql_table(con=postgresql_con, schema="public", table=postgresql_table)

    df = pd.DataFrame({"c0": ["foo", "bar"], "c1": [1, 2], "c2": [3, 4]})

    # Writing the same rows twice must not create duplicates.
    upsert(df)
    upsert(df)
    df2 = read_back()
    assert bool(len(df2) == 2)

    # One row conflicts on (c1, c2) and updates; the other is new.
    df3 = pd.DataFrame({"c0": ["baz", "spam"], "c1": [1, 5], "c2": [3, 2]})
    upsert(df3)
    df4 = read_back()
    assert bool(len(df4) == 3)

    # Both rows conflict on (c1, c2): c0 is updated, row count is unchanged.
    df5 = pd.DataFrame({"c0": ["egg", "spam"], "c1": [2, 5], "c2": [4, 2]})
    upsert(df5)

    df6 = read_back()
    expected = pd.DataFrame({"c0": ["baz", "egg", "spam"], "c1": [1, 2, 5], "c2": [3, 4, 2]})
    expected = expected.astype({"c0": "string", "c1": "Int64", "c2": "Int64"})
    assert df6.equals(expected)