Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 163 additions & 23 deletions providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import os
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from contextlib import closing
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, cast, overload
Expand Down Expand Up @@ -230,6 +230,29 @@ def _get_cursor(self, raw_cursor: str) -> CursorType:
valid_cursors = ", ".join(cursor_types.keys())
raise ValueError(f"Invalid cursor passed {_cursor}. Valid options are: {valid_cursors}")

def _get_cursor_config(self, raw_cursor: str) -> tuple[str, Any]:
cursor = self._get_cursor(raw_cursor)

if USE_PSYCOPG3:
return "row_factory", cursor

return "cursor_factory", cursor

def _create_connection(self, conn_args: dict[str, Any]) -> CompatConnection:
if USE_PSYCOPG3:
from psycopg.connection import Connection as pgConnection

connection = pgConnection.connect(**cast("Any", conn_args))

register_default_adapters(connection)

if self.enable_log_db_messages and hasattr(connection, "add_notice_handler"):
connection.add_notice_handler(self._notice_handler)

return connection

return ppg2_connect(**conn_args)

def _generate_cursor_name(self):
"""Generate a unique name for server-side cursor."""
import uuid
Expand Down Expand Up @@ -262,30 +285,13 @@ def get_conn(self) -> CompatConnection:
if arg_name not in self.ignored_extra_options:
conn_args[arg_name] = arg_val

if USE_PSYCOPG3:
from psycopg.connection import Connection as pgConnection
raw_cursor = conn.extra_dejson.get("cursor")

raw_cursor = conn.extra_dejson.get("cursor")
if raw_cursor:
conn_args["row_factory"] = self._get_cursor(raw_cursor)
if raw_cursor:
key, value = self._get_cursor_config(raw_cursor)
conn_args[key] = value

# Use Any type for the connection args to avoid type conflicts
connection = pgConnection.connect(**cast("Any", conn_args))
self.conn = cast("CompatConnection", connection)

# Register JSON handlers for both json and jsonb types
# This ensures JSON data is properly decoded from bytes to Python objects
register_default_adapters(connection)

# Add the notice handler AFTER the connection is established
if self.enable_log_db_messages and hasattr(self.conn, "add_notice_handler"):
self.conn.add_notice_handler(self._notice_handler)
else: # psycopg2
raw_cursor = conn.extra_dejson.get("cursor", False)
if raw_cursor:
conn_args["cursor_factory"] = self._get_cursor(raw_cursor)

self.conn = cast("CompatConnection", ppg2_connect(**conn_args))
self.conn = self._create_connection(conn_args)

return self.conn

Expand Down Expand Up @@ -720,3 +726,137 @@ def insert_rows(

self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table)
return None

def _generate_upsert_sql(
self,
table: str,
values: tuple[Any, ...] | list[Any],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to pass in the values here? The only thing it's used for is to produce the right number of placeholders but I think that can just be done with len(target_fields).

target_fields: list[str],
conflict_fields: list[str],
update_fields: list[str] | None = None,
**kwargs,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need the kwargs here as there is nothing consuming them.

) -> str:
"""
Generate PostgreSQL UPSERT SQL using ``ON CONFLICT``.

:param table: Name of target table.
:param values: Row values used for placeholder generation.
:param target_fields: Non-empty column names used in the ``INSERT`` statement.
:param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause.
:param update_fields: Columns to update on conflict. If omitted or empty,
``DO NOTHING`` is used.
"""
placeholders = ", ".join(["%s"] * len(values))

columns = ", ".join(target_fields)

conflict_clause = ", ".join(conflict_fields)

sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders}) "

if update_fields:
update_clause = ", ".join(f"{field} = EXCLUDED.{field}" for field in update_fields)

sql += f"ON CONFLICT ({conflict_clause}) DO UPDATE SET {update_clause}"
else:
sql += f"ON CONFLICT ({conflict_clause}) DO NOTHING"

return sql.strip()

def upsert_rows(
self,
table: str,
rows: Iterable[tuple[Any, ...]],
target_fields: list[str],
conflict_fields: list[str],
update_fields: list[str] | None = None,
commit_every: int = 1000,
*,
fast_executemany: bool = False,
autocommit: bool = False,
) -> None:
"""
Upsert rows into a PostgreSQL table using ``ON CONFLICT``.

:param table: Name of the target table.
:param rows: Rows to upsert.
:param target_fields: Non-empty column names used in the ``INSERT`` statement.
:param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause.
:param update_fields: Columns updated on conflict. If omitted or empty,
conflicting rows are ignored via ``DO NOTHING``.
:param commit_every: Maximum number of rows per transaction. Default value is 1000.
:param fast_executemany: Use ``psycopg2.extras.execute_batch`` for improved
batch performance.
:param autocommit: Connection autocommit setting.
"""
if not target_fields or any(not field for field in target_fields):
raise ValueError("target_fields must be provided and must not be empty.")

if not conflict_fields or any(not field for field in conflict_fields):
raise ValueError("conflict_fields must be provided and must not be empty.")

rows = iter(rows)

nb_rows = 0
sql = None

with self._create_autocommit_connection(autocommit) as conn:
conn.commit()

with closing(conn.cursor()) as cur:
for chunked_rows in chunked(rows, commit_every):
values = [self._serialize_cells(row, conn) for row in chunked_rows]

if not values:
continue

sql = self._generate_upsert_sql(
table=table,
values=values[0],
target_fields=target_fields,
conflict_fields=conflict_fields,
update_fields=update_fields,
)
Comment on lines +813 to +819
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to my comment above - if we don't need to pass values since it's just used for the number of placeholders, I don't think we need to regenerate the same SQL string on every chunk. This can just be done once outside of the loop.


self.log.debug("Generated sql: %s", sql)

try:
if fast_executemany:
# execute_batch reduces round trips by batching parameter sets.
execute_batch(
Comment on lines +824 to +826
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a guard here since psycopg3 does not support execute_batch? Maybe using the USE_PSYCOPG3 constant that's used in other parts of the code. Either logging a warning and defaulting back to cur.executemany or raising an error.

cur,
sql,
values,
page_size=commit_every,
)
else:
cur.executemany(sql, values)

except Exception:
self.log.error("Generated sql: %s", sql)
self.log.error("Parameters: %s", values)
raise

conn.commit()

nb_rows += len(values)

self.log.info(
"Upserted %s rows into %s so far",
nb_rows,
table,
)

if sql:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also related to my comment above, if we construct the query once outside the loop then this would need to be updated. This could be if nb_rows > 0.

send_sql_hook_lineage(
context=self,
sql=sql,
row_count=nb_rows,
)

self.log.info(
"Done upserting. Upserted a total of %s rows into %s",
nb_rows,
table,
)
return None
Loading
Loading