From 4013efe2896c8ac2d6cee1075974863d03d7c04e Mon Sep 17 00:00:00 2001 From: Sameer Mesiah Date: Thu, 14 May 2026 01:06:39 +0100 Subject: [PATCH 1/3] Refactor shared PostgresHook runtime test coverage for psycopg2 and psycopg3. Consolidate duplicated insert/upsert and dialect tests into a shared base class while preserving version-specific behavior and lineage coverage. --- .../providers/postgres/hooks/postgres.py | 50 +- .../unit/postgres/hooks/test_postgres.py | 435 ++++++------------ 2 files changed, 163 insertions(+), 322 deletions(-) diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index d38171bbe21db..75fb5df6bddce 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -230,6 +230,29 @@ def _get_cursor(self, raw_cursor: str) -> CursorType: valid_cursors = ", ".join(cursor_types.keys()) raise ValueError(f"Invalid cursor passed {_cursor}. 
Valid options are: {valid_cursors}") + def _get_cursor_config(self, raw_cursor: str) -> tuple[str, Any]: + cursor = self._get_cursor(raw_cursor) + + if USE_PSYCOPG3: + return "row_factory", cursor + + return "cursor_factory", cursor + + def _create_connection(self, conn_args: dict[str, Any]) -> CompatConnection: + if USE_PSYCOPG3: + from psycopg.connection import Connection as pgConnection + + connection = pgConnection.connect(**cast("Any", conn_args)) + + register_default_adapters(connection) + + if self.enable_log_db_messages and hasattr(connection, "add_notice_handler"): + connection.add_notice_handler(self._notice_handler) + + return cast("CompatConnection", connection) + + return cast("CompatConnection", ppg2_connect(**conn_args)) + def _generate_cursor_name(self): """Generate a unique name for server-side cursor.""" import uuid @@ -262,30 +285,13 @@ def get_conn(self) -> CompatConnection: if arg_name not in self.ignored_extra_options: conn_args[arg_name] = arg_val - if USE_PSYCOPG3: - from psycopg.connection import Connection as pgConnection - - raw_cursor = conn.extra_dejson.get("cursor") - if raw_cursor: - conn_args["row_factory"] = self._get_cursor(raw_cursor) - - # Use Any type for the connection args to avoid type conflicts - connection = pgConnection.connect(**cast("Any", conn_args)) - self.conn = cast("CompatConnection", connection) - - # Register JSON handlers for both json and jsonb types - # This ensures JSON data is properly decoded from bytes to Python objects - register_default_adapters(connection) + raw_cursor = conn.extra_dejson.get("cursor") - # Add the notice handler AFTER the connection is established - if self.enable_log_db_messages and hasattr(self.conn, "add_notice_handler"): - self.conn.add_notice_handler(self._notice_handler) - else: # psycopg2 - raw_cursor = conn.extra_dejson.get("cursor", False) - if raw_cursor: - conn_args["cursor_factory"] = self._get_cursor(raw_cursor) + if raw_cursor: + key, value = 
self._get_cursor_config(raw_cursor) + conn_args[key] = value - self.conn = cast("CompatConnection", ppg2_connect(**conn_args)) + self.conn = self._create_connection(conn_args) return self.conn diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py index 6cad536c07b78..fab8bc6f5d7d5 100644 --- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py @@ -19,7 +19,6 @@ import json import os -from types import SimpleNamespace from unittest import mock import pandas as pd @@ -30,7 +29,7 @@ from airflow.models import Connection from airflow.providers.common.compat.sdk import AirflowException, AirflowOptionalProviderFeatureException from airflow.providers.postgres.dialects.postgres import PostgresDialect -from airflow.providers.postgres.hooks.postgres import CompatConnection, PostgresHook +from airflow.providers.postgres.hooks.postgres import PostgresHook from tests_common.test_utils.common_sql import mock_db_hook from tests_common.test_utils.version_compat import NOTSET @@ -59,36 +58,6 @@ import psycopg2.extras -@pytest.fixture -def postgres_hook_setup(): - """Set up mock PostgresHook for testing.""" - table = "test_postgres_hook_table" - cur = mock.MagicMock(rowcount=0) - conn = mock.MagicMock(spec=CompatConnection) - conn.cursor.return_value = cur - - class UnitTestPostgresHook(PostgresHook): - conn_name_attr = "test_conn_id" - - def get_conn(self): - return conn - - db_hook = UnitTestPostgresHook() - - # Return a namespace with all the objects - setup = SimpleNamespace(table=table, cur=cur, conn=conn, db_hook=db_hook) - - yield setup - - # Teardown - only for real database tests - try: - with PostgresHook().get_conn() as real_conn: - with real_conn.cursor() as real_cur: - real_cur.execute(f"DROP TABLE IF EXISTS {table}") - except Exception: - pass # Ignore cleanup errors for unit tests - - @pytest.fixture def 
mock_connect(mocker): """Mock the connection object according to the correct psycopg version.""" @@ -816,10 +785,8 @@ def test_generate_insert_sql_with_already_escaped_column_name(self): ) == INSERT_SQL_STATEMENT.format('"schema"') -@pytest.mark.backend("postgres") -@pytest.mark.skipif(USE_PSYCOPG3, reason="psycopg v3 is available") -class TestPostgresHookPPG2: - """PostgresHook tests that are specific to psycopg2.""" +class _BasePostgresHookRuntimeTests: + """Shared runtime tests for psycopg2 and psycopg3.""" table = "test_postgres_hook_table" @@ -841,6 +808,121 @@ def teardown_method(self): with conn.cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table}") + def test_insert_rows(self): + table = "table" + rows = [("hello",), ("world",)] + + self.db_hook.insert_rows(table, rows) + + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 + assert self.conn.commit.call_count == 2 + + sql = f"INSERT INTO {table} VALUES (%s)" + self.cur.executemany.assert_any_call(sql, rows) + + def test_insert_rows_replace(self): + table = "table" + rows = [ + (1, "hello"), + (2, "world"), + ] + fields = ("id", "value") + + self.db_hook.insert_rows( + table, + rows, + fields, + replace=True, + replace_index=fields[0], + ) + + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 + assert self.conn.commit.call_count == 2 + + sql = ( + f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) " + f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = excluded.{fields[1]}" + ) + self.cur.executemany.assert_any_call(sql, rows) + + def test_insert_rows_replace_missing_target_field_arg(self): + table = "table" + rows = [ + (1, "hello"), + (2, "world"), + ] + fields = ("id", "value") + + with pytest.raises( + ValueError, + match="PostgreSQL ON CONFLICT upsert syntax requires column names", + ): + self.db_hook.insert_rows( + table, + rows, + replace=True, + replace_index=fields[0], + ) + + def 
test_insert_rows_replace_missing_replace_index_arg(self): + table = "table" + rows = [ + (1, "hello"), + (2, "world"), + ] + fields = ("id", "value") + + with pytest.raises( + ValueError, + match="PostgreSQL ON CONFLICT upsert syntax requires an unique index", + ): + self.db_hook.insert_rows( + table, + rows, + fields, + replace=True, + ) + + def test_insert_rows_replace_all_index(self): + table = "table" + rows = [ + (1, "hello"), + (2, "world"), + ] + fields = ("id", "value") + + self.db_hook.insert_rows( + table, + rows, + fields, + replace=True, + replace_index=fields, + ) + + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 + assert self.conn.commit.call_count == 2 + + sql = ( + f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) " + f"ON CONFLICT ({', '.join(fields)}) DO NOTHING" + ) + self.cur.executemany.assert_any_call(sql, rows) + + def test_dialect_name(self): + assert self.db_hook.dialect_name == "postgresql" + + def test_dialect(self): + assert isinstance(self.db_hook.dialect, PostgresDialect) + + +@pytest.mark.backend("postgres") +@pytest.mark.skipif(USE_PSYCOPG3, reason="psycopg v3 is available") +class TestPostgresHookPPG2(_BasePostgresHookRuntimeTests): + """PostgresHook tests that are specific to psycopg2.""" + def test_copy_expert(self, mocker): open_mock = mocker.mock_open(read_data='{"some": "json"}') mocker.patch("airflow.providers.postgres.hooks.postgres.open", open_mock) @@ -915,169 +997,59 @@ def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_ assert call_kw["sql"] == sql assert call_kw["sql_parameters"] == parameters - def test_insert_rows(self, postgres_hook_setup): - setup = postgres_hook_setup - table = "table" - rows = [("hello",), ("world",)] - - setup.db_hook.insert_rows(table, rows) - - assert setup.conn.close.call_count == 1 - assert setup.cur.close.call_count == 1 - - commit_count = 2 # The first and last commit - assert commit_count == setup.conn.commit.call_count 
- - sql = f"INSERT INTO {table} VALUES (%s)" - setup.cur.executemany.assert_any_call(sql, rows) - @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") - def test_insert_rows_hook_lineage(self, mock_send_lineage, postgres_hook_setup): - setup = postgres_hook_setup + def test_insert_rows_hook_lineage(self, mock_send_lineage): table = "table" rows = [("hello",), ("world",)] - setup.db_hook.insert_rows(table, rows) + self.db_hook.insert_rows(table, rows) mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs - assert call_kw["context"] is setup.db_hook + + assert call_kw["context"] is self.db_hook assert call_kw["sql"] == f"INSERT INTO {table} VALUES (%s)" assert call_kw["row_count"] == 2 @mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch") - def test_insert_rows_fast_executemany(self, mock_execute_batch, postgres_hook_setup): - setup = postgres_hook_setup + def test_insert_rows_fast_executemany(self, mock_execute_batch): table = "table" rows = [("hello",), ("world",)] - setup.db_hook.insert_rows(table, rows, fast_executemany=True) + self.db_hook.insert_rows(table, rows, fast_executemany=True) - assert setup.conn.close.call_count == 1 - assert setup.cur.close.call_count == 1 + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 commit_count = 2 # The first and last commit - assert setup.conn.commit.call_count == commit_count + assert self.conn.commit.call_count == commit_count mock_execute_batch.assert_called_once_with( - setup.cur, + self.cur, f"INSERT INTO {table} VALUES (%s)", # expected SQL [("hello",), ("world",)], # expected values page_size=1000, ) # executemany should NOT be called in this mode - setup.cur.executemany.assert_not_called() + self.cur.executemany.assert_not_called() @mock.patch("airflow.providers.postgres.hooks.postgres.send_sql_hook_lineage") @mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch") - def 
test_insert_rows_fast_executemany_hook_lineage( - self, mock_execute_batch, mock_send_lineage, postgres_hook_setup - ): - setup = postgres_hook_setup + def test_insert_rows_fast_executemany_hook_lineage(self, mock_execute_batch, mock_send_lineage): + table = "table" rows = [("hello",), ("world",)] - setup.db_hook.insert_rows(table, rows, fast_executemany=True) + self.db_hook.insert_rows(table, rows, fast_executemany=True) mock_send_lineage.assert_called_once() call_kw = mock_send_lineage.call_args.kwargs - assert call_kw["context"] is setup.db_hook + assert call_kw["context"] is self.db_hook assert call_kw["sql"] == f"INSERT INTO {table} VALUES (%s)" assert call_kw["row_count"] == 2 - def test_insert_rows_replace(self, postgres_hook_setup): - setup = postgres_hook_setup - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - - setup.db_hook.insert_rows(table, rows, fields, replace=True, replace_index=fields[0]) - - assert setup.conn.close.call_count == 1 - assert setup.cur.close.call_count == 1 - - commit_count = 2 # The first and last commit - assert commit_count == setup.conn.commit.call_count - - sql = ( - f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) " - f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = excluded.{fields[1]}" - ) - setup.cur.executemany.assert_any_call(sql, rows) - - def test_insert_rows_replace_missing_target_field_arg(self, postgres_hook_setup): - setup = postgres_hook_setup - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert syntax requires column names"): - setup.db_hook.insert_rows(table, rows, replace=True, replace_index=fields[0]) - - def test_insert_rows_replace_missing_replace_index_arg(self, postgres_hook_setup): - setup = postgres_hook_setup - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - 
fields = ("id", "value") - with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert syntax requires an unique index"): - setup.db_hook.insert_rows(table, rows, fields, replace=True) - - def test_insert_rows_replace_all_index(self, postgres_hook_setup): - setup = postgres_hook_setup - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - - setup.db_hook.insert_rows(table, rows, fields, replace=True, replace_index=fields) - - assert setup.conn.close.call_count == 1 - assert setup.cur.close.call_count == 1 - - commit_count = 2 # The first and last commit - assert commit_count == setup.conn.commit.call_count - - sql = ( - f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) " - f"ON CONFLICT ({', '.join(fields)}) DO NOTHING" - ) - setup.cur.executemany.assert_any_call(sql, rows) - @pytest.mark.usefixtures("reset_logging_config") def test_get_all_db_log_messages(self, mocker): messages = ["a", "b", "c"] @@ -1120,40 +1092,12 @@ def test_log_db_messages_by_db_proc(self, mocker): finally: hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)") - def test_dialect_name(self, postgres_hook_setup): - setup = postgres_hook_setup - assert setup.db_hook.dialect_name == "postgresql" - - def test_dialect(self, postgres_hook_setup): - setup = postgres_hook_setup - assert isinstance(setup.db_hook.dialect, PostgresDialect) - @pytest.mark.backend("postgres") @pytest.mark.skipif(not USE_PSYCOPG3, reason="psycopg v3 or sqlalchemy v2 are not available") -class TestPostgresHookPPG3: +class TestPostgresHookPPG3(_BasePostgresHookRuntimeTests): """PostgresHook tests that are specific to psycopg3.""" - table = "test_postgres_hook_table" - - def setup_method(self): - self.cur = mock.MagicMock(rowcount=0) - self.conn = conn = mock.MagicMock() - self.conn.cursor.return_value = self.cur - - class UnitTestPostgresHook(PostgresHook): - conn_name_attr = "test_conn_id" - - def get_conn(self): - return conn - - self.db_hook = 
UnitTestPostgresHook() - - def teardown_method(self): - with PostgresHook().get_conn() as conn: - with conn.cursor() as cur: - cur.execute(f"DROP TABLE IF EXISTS {self.table}") - def test_copy_expert_from(self, mocker): """Tests copy_expert with a 'COPY FROM STDIN' operation.""" statement = "COPY test_table FROM STDIN" @@ -1235,109 +1179,6 @@ def test_copy_expert_to(self, mocker): ) self.conn.commit.assert_called_once() - def test_insert_rows(self): - table = "table" - rows = [("hello",), ("world",)] - - self.db_hook.insert_rows(table, rows) - - assert self.conn.close.call_count == 1 - assert self.cur.close.call_count == 1 - - commit_count = 2 # The first and last commit - assert commit_count == self.conn.commit.call_count - - sql = f"INSERT INTO {table} VALUES (%s)" - self.cur.executemany.assert_any_call(sql, rows) - - def test_insert_rows_replace(self): - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - - self.db_hook.insert_rows(table, rows, fields, replace=True, replace_index=fields[0]) - - assert self.conn.close.call_count == 1 - assert self.cur.close.call_count == 1 - - commit_count = 2 # The first and last commit - assert commit_count == self.conn.commit.call_count - - sql = ( - f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) " - f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = excluded.{fields[1]}" - ) - self.cur.executemany.assert_any_call(sql, rows) - - def test_insert_rows_replace_missing_target_field_arg(self): - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert syntax requires column names"): - self.db_hook.insert_rows(table, rows, replace=True, replace_index=fields[0]) - - def test_insert_rows_replace_missing_replace_index_arg(self): - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - 
with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert syntax requires an unique index"): - self.db_hook.insert_rows(table, rows, fields, replace=True) - - def test_insert_rows_replace_all_index(self): - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - - self.db_hook.insert_rows(table, rows, fields, replace=True, replace_index=fields) - - assert self.conn.close.call_count == 1 - assert self.cur.close.call_count == 1 - - commit_count = 2 # The first and last commit - assert commit_count == self.conn.commit.call_count - - sql = ( - f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) " - f"ON CONFLICT ({', '.join(fields)}) DO NOTHING" - ) - self.cur.executemany.assert_any_call(sql, rows) - @pytest.mark.skip(reason="Notice handling is callback-based in psycopg3 and cannot be tested this way.") def test_get_all_db_log_messages(self, mocker): pass @@ -1366,9 +1207,3 @@ def test_log_db_messages_by_db_proc(self, mocker): mock_logger.info.assert_any_call("Message from db: 42") finally: hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)") - - def test_dialect_name(self): - assert self.db_hook.dialect_name == "postgresql" - - def test_dialect(self): - assert isinstance(self.db_hook.dialect, PostgresDialect) From c71bffd5878b493cdf7522f281c448b0fcbef5da Mon Sep 17 00:00:00 2001 From: Sameer Mesiah Date: Thu, 14 May 2026 19:31:54 +0100 Subject: [PATCH 2/3] Remove unnecessary typing casts --- .../postgres/src/airflow/providers/postgres/hooks/postgres.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index 75fb5df6bddce..ec28bcc834fed 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -249,9 +249,9 @@ def _create_connection(self, 
conn_args: dict[str, Any]) -> CompatConnection: if self.enable_log_db_messages and hasattr(connection, "add_notice_handler"): connection.add_notice_handler(self._notice_handler) - return cast("CompatConnection", connection) + return connection - return cast("CompatConnection", ppg2_connect(**conn_args)) + return ppg2_connect(**conn_args) def _generate_cursor_name(self): """Generate a unique name for server-side cursor.""" From 5709c065c5ada02961c4a4a9d909db584e3c2c5b Mon Sep 17 00:00:00 2001 From: Sameer Mesiah Date: Sat, 16 May 2026 17:05:27 +0100 Subject: [PATCH 3/3] Add PostgreSQL UPSERT support to PostgresHook using ON CONFLICT. Introduce upsert_rows and UPSERT SQL generation with support for DO UPDATE, DO NOTHING, chunked commits, and execute_batch optimization. Added unit tests for UPSERT SQL generation, DO UPDATE and DO NOTHING behavior, chunked execution, execute_batch execution, validation checks, and transaction handling. --- .../providers/postgres/hooks/postgres.py | 136 ++++++++++++- .../unit/postgres/hooks/test_postgres.py | 190 ++++++++++++++++++ 2 files changed, 325 insertions(+), 1 deletion(-) diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index ec28bcc834fed..dd02fa862e9a8 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -18,7 +18,7 @@ from __future__ import annotations import os -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from contextlib import closing from copy import deepcopy from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, cast, overload @@ -726,3 +726,137 @@ def insert_rows( self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table) return None + + def _generate_upsert_sql( + self, + table: str, + values: tuple[Any, ...] 
| list[Any], + target_fields: list[str], + conflict_fields: list[str], + update_fields: list[str] | None = None, + **kwargs, + ) -> str: + """ + Generate PostgreSQL UPSERT SQL using ``ON CONFLICT``. + + :param table: Name of target table. + :param values: Row values used for placeholder generation. + :param target_fields: Non-empty column names used in the ``INSERT`` statement. + :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause. + :param update_fields: Columns to update on conflict. If omitted or empty, + ``DO NOTHING`` is used. + """ + placeholders = ", ".join(["%s"] * len(values)) + + columns = ", ".join(target_fields) + + conflict_clause = ", ".join(conflict_fields) + + sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders}) " + + if update_fields: + update_clause = ", ".join(f"{field} = EXCLUDED.{field}" for field in update_fields) + + sql += f"ON CONFLICT ({conflict_clause}) DO UPDATE SET {update_clause}" + else: + sql += f"ON CONFLICT ({conflict_clause}) DO NOTHING" + + return sql.strip() + + def upsert_rows( + self, + table: str, + rows: Iterable[tuple[Any, ...]], + target_fields: list[str], + conflict_fields: list[str], + update_fields: list[str] | None = None, + commit_every: int = 1000, + *, + fast_executemany: bool = False, + autocommit: bool = False, + ) -> None: + """ + Upsert rows into a PostgreSQL table using ``ON CONFLICT``. + + :param table: Name of the target table. + :param rows: Rows to upsert. + :param target_fields: Non-empty column names used in the ``INSERT`` statement. + :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause. + :param update_fields: Columns updated on conflict. If omitted or empty, + conflicting rows are ignored via ``DO NOTHING``. + :param commit_every: Maximum number of rows per transaction. Default value is 1000. + :param fast_executemany: Use ``psycopg2.extras.execute_batch`` for improved + batch performance. 
+ :param autocommit: Connection autocommit setting. + """ + if not target_fields or any(not field for field in target_fields): + raise ValueError("target_fields must be provided and must not be empty.") + + if not conflict_fields or any(not field for field in conflict_fields): + raise ValueError("conflict_fields must be provided and must not be empty.") + + rows = iter(rows) + + nb_rows = 0 + sql = None + + with self._create_autocommit_connection(autocommit) as conn: + conn.commit() + + with closing(conn.cursor()) as cur: + for chunked_rows in chunked(rows, commit_every): + values = [self._serialize_cells(row, conn) for row in chunked_rows] + + if not values: + continue + + sql = self._generate_upsert_sql( + table=table, + values=values[0], + target_fields=target_fields, + conflict_fields=conflict_fields, + update_fields=update_fields, + ) + + self.log.debug("Generated sql: %s", sql) + + try: + if fast_executemany: + # execute_batch reduces round trips by batching parameter sets. + execute_batch( + cur, + sql, + values, + page_size=commit_every, + ) + else: + cur.executemany(sql, values) + + except Exception: + self.log.error("Generated sql: %s", sql) + self.log.error("Parameters: %s", values) + raise + + conn.commit() + + nb_rows += len(values) + + self.log.info( + "Upserted %s rows into %s so far", + nb_rows, + table, + ) + + if sql: + send_sql_hook_lineage( + context=self, + sql=sql, + row_count=nb_rows, + ) + + self.log.info( + "Done upserting. 
Upserted a total of %s rows into %s", + nb_rows, + table, + ) + return None diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py index fab8bc6f5d7d5..4d32f479ebab3 100644 --- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py @@ -911,6 +911,196 @@ def test_insert_rows_replace_all_index(self): ) self.cur.executemany.assert_any_call(sql, rows) + @pytest.mark.parametrize( + ( + "rows", + "target_fields", + "conflict_fields", + "update_fields", + "expected_sql", + ), + [ + ( + [(1, "hello"), (2, "world")], + ["id", "value"], + ["id"], + ["value"], + ( + "INSERT INTO table (id, value) " + "VALUES (%s, %s) " + "ON CONFLICT (id) " + "DO UPDATE SET value = EXCLUDED.value" + ), + ), + ( + [("hello",)], + ["value"], + ["value"], + None, + ("INSERT INTO table (value) VALUES (%s) ON CONFLICT (value) DO NOTHING"), + ), + ( + [(1, 10, "hello")], + ["id", "tenant_id", "value"], + ["id", "tenant_id"], + ["value"], + ( + "INSERT INTO table (id, tenant_id, value) " + "VALUES (%s, %s, %s) " + "ON CONFLICT (id, tenant_id) " + "DO UPDATE SET value = EXCLUDED.value" + ), + ), + ( + [(1, "alice", "a@example.com")], + ["id", "name", "email"], + ["id"], + ["name", "email"], + ( + "INSERT INTO table (id, name, email) " + "VALUES (%s, %s, %s) " + "ON CONFLICT (id) " + "DO UPDATE SET " + "name = EXCLUDED.name, email = EXCLUDED.email" + ), + ), + ], + ) + def test_upsert_rows( + self, + rows, + target_fields, + conflict_fields, + update_fields, + expected_sql, + ): + self.db_hook.upsert_rows( + table="table", + rows=rows, + target_fields=target_fields, + conflict_fields=conflict_fields, + update_fields=update_fields, + ) + + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 + assert self.conn.commit.call_count == 2 + + self.cur.executemany.assert_called_once_with( + expected_sql, + rows, + ) + + 
@mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch") + def test_upsert_rows_fast_executemany(self, mock_execute_batch): + + table = "table" + rows = [ + (1, "hello"), + (2, "world"), + ] + + self.db_hook.upsert_rows( + table=table, + rows=rows, + target_fields=["id", "value"], + conflict_fields=["id"], + update_fields=["value"], + fast_executemany=True, + ) + + sql = ( + "INSERT INTO table (id, value) " + "VALUES (%s, %s) " + "ON CONFLICT (id) " + "DO UPDATE SET value = EXCLUDED.value" + ) + + mock_execute_batch.assert_called_once_with( + self.cur, + sql, + rows, + page_size=1000, + ) + + self.cur.executemany.assert_not_called() + + def test_upsert_rows_commit_every(self): + rows = [ + (1, "hello"), + (2, "world"), + ] + + self.db_hook.upsert_rows( + table="table", + rows=rows, + target_fields=["id", "value"], + conflict_fields=["id"], + update_fields=["value"], + commit_every=1, + ) + + sql = ( + "INSERT INTO table (id, value) " + "VALUES (%s, %s) " + "ON CONFLICT (id) " + "DO UPDATE SET value = EXCLUDED.value" + ) + + assert self.cur.executemany.call_count == 2 + + self.cur.executemany.assert_has_calls( + [ + mock.call(sql, [(1, "hello")]), + mock.call(sql, [(2, "world")]), + ] + ) + + # Initial commit and one commit per chunk. + assert self.conn.commit.call_count == 3 + + @mock.patch("airflow.providers.postgres.hooks.postgres.send_sql_hook_lineage") + def test_upsert_rows_empty_rows(self, mock_lineage): + self.db_hook.upsert_rows( + table="table", + rows=[], + target_fields=["id", "value"], + conflict_fields=["id"], + update_fields=["value"], + ) + + self.cur.executemany.assert_not_called() + + # Only initial transaction reset commit should occur. 
+ assert self.conn.commit.call_count == 1 + + mock_lineage.assert_not_called() + + @pytest.mark.parametrize( + ("target_fields", "conflict_fields", "error"), + [ + ([], ["id"], "target_fields must be provided and must not be empty."), + ([""], ["id"], "target_fields must be provided and must not be empty."), + (["id", ""], ["id"], "target_fields must be provided and must not be empty."), + (["id"], [], "conflict_fields must be provided and must not be empty."), + (["id"], [""], "conflict_fields must be provided and must not be empty."), + (["id"], ["id", ""], "conflict_fields must be provided and must not be empty."), + ], + ) + def test_upsert_rows_invalid_fields( + self, + target_fields, + conflict_fields, + error, + ): + with pytest.raises(ValueError, match=error): + self.db_hook.upsert_rows( + table="table", + rows=[(1, "hello")], + target_fields=target_fields, + conflict_fields=conflict_fields, + ) + def test_dialect_name(self): assert self.db_hook.dialect_name == "postgresql"