Skip to content

Commit

Permalink
PostgresHook: Added ON CONFLICT DO NOTHING statement when all target …
Browse files Browse the repository at this point in the history
…fields are primary keys (#26661)

 Added ON CONFLICT DO NOTHING statement when all target fields are primary keys. Also made _generate_insert_sql use cls.placeholder for standardization's sake
  • Loading branch information
alexandermalyga committed Nov 10, 2022
1 parent 59e3198 commit d7d2061
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 10 deletions.
24 changes: 14 additions & 10 deletions airflow/providers/postgres/hooks/postgres.py
Expand Up @@ -262,9 +262,9 @@ def get_table_primary_key(self, table: str, schema: str | None = "public") -> li
pk_columns = [row[0] for row in self.get_records(sql, (schema, table))]
return pk_columns or None

@staticmethod
@classmethod
def _generate_insert_sql(
table: str, values: tuple[str, ...], target_fields: Iterable[str], replace: bool, **kwargs
cls, table: str, values: tuple[str, ...], target_fields: Iterable[str], replace: bool, **kwargs
) -> str:
"""
Static helper method that generates the INSERT SQL statement.
Expand All @@ -279,7 +279,7 @@ def _generate_insert_sql(
:return: The generated INSERT or REPLACE SQL statement
"""
placeholders = [
"%s",
cls.placeholder,
] * len(values)
replace_index = kwargs.get("replace_index")

Expand All @@ -292,16 +292,20 @@ def _generate_insert_sql(
sql = f"INSERT INTO {table} {target_fields_fragment} VALUES ({','.join(placeholders)})"

if replace:
if target_fields is None:
if not target_fields:
raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires column names")
if replace_index is None:
if not replace_index:
raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires an unique index")
if isinstance(replace_index, str):
replace_index = [replace_index]
replace_index_set = set(replace_index)

replace_target = [
"{0} = excluded.{0}".format(col) for col in target_fields if col not in replace_index_set
]
sql += f" ON CONFLICT ({', '.join(replace_index)}) DO UPDATE SET {', '.join(replace_target)}"
on_conflict_str = f" ON CONFLICT ({', '.join(replace_index)})"
replace_target = [f for f in target_fields if f not in replace_index]

if replace_target:
replace_target_str = ", ".join(f"{col} = excluded.{col}" for col in replace_target)
sql += f"{on_conflict_str} DO UPDATE SET {replace_target_str}"
else:
sql += f"{on_conflict_str} DO NOTHING"

return sql
30 changes: 30 additions & 0 deletions tests/providers/postgres/hooks/test_postgres.py
Expand Up @@ -426,6 +426,36 @@ def test_insert_rows_replace_missing_replace_index_arg(self):

assert str(ctx.value) == "PostgreSQL ON CONFLICT upsert syntax requires an unique index"

@pytest.mark.backend("postgres")
def test_insert_rows_replace_all_index(self):
    """Check the SQL issued when every target field is also a replace-index column.

    With ``replace_index`` equal to the full field tuple, the expected
    statement ends in ``ON CONFLICT (...) DO NOTHING``.
    """
    table = "table"
    fields = ("id", "value")
    rows = [(1, "hello"), (2, "world")]

    self.db_hook.insert_rows(table, rows, fields, replace=True, replace_index=fields)

    # Connection and cursor must each be closed exactly once.
    assert self.conn.close.call_count == 1
    assert self.cur.close.call_count == 1

    # The first and last commit
    assert self.conn.commit.call_count == 2

    # The same column list appears in the INSERT target and the conflict target.
    column_list = ", ".join(fields)
    expected_sql = (
        f"INSERT INTO {table} ({column_list}) VALUES (%s,%s) "
        f"ON CONFLICT ({column_list}) DO NOTHING"
    )
    for row in rows:
        self.cur.execute.assert_any_call(expected_sql, row)

@pytest.mark.backend("postgres")
def test_rowcount(self):
hook = PostgresHook()
Expand Down

0 comments on commit d7d2061

Please sign in to comment.