diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index d38171bbe21db..ec28bcc834fed 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -230,6 +230,29 @@ def _get_cursor(self, raw_cursor: str) -> CursorType: valid_cursors = ", ".join(cursor_types.keys()) raise ValueError(f"Invalid cursor passed {_cursor}. Valid options are: {valid_cursors}") + def _get_cursor_config(self, raw_cursor: str) -> tuple[str, Any]: + cursor = self._get_cursor(raw_cursor) + + if USE_PSYCOPG3: + return "row_factory", cursor + + return "cursor_factory", cursor + + def _create_connection(self, conn_args: dict[str, Any]) -> CompatConnection: + if USE_PSYCOPG3: + from psycopg.connection import Connection as pgConnection + + connection = pgConnection.connect(**cast("Any", conn_args)) + + register_default_adapters(connection) + + if self.enable_log_db_messages and hasattr(connection, "add_notice_handler"): + connection.add_notice_handler(self._notice_handler) + + return connection + + return ppg2_connect(**conn_args) + def _generate_cursor_name(self): """Generate a unique name for server-side cursor.""" import uuid @@ -262,30 +285,13 @@ def get_conn(self) -> CompatConnection: if arg_name not in self.ignored_extra_options: conn_args[arg_name] = arg_val - if USE_PSYCOPG3: - from psycopg.connection import Connection as pgConnection - - raw_cursor = conn.extra_dejson.get("cursor") - if raw_cursor: - conn_args["row_factory"] = self._get_cursor(raw_cursor) - - # Use Any type for the connection args to avoid type conflicts - connection = pgConnection.connect(**cast("Any", conn_args)) - self.conn = cast("CompatConnection", connection) - - # Register JSON handlers for both json and jsonb types - # This ensures JSON data is properly decoded from bytes to Python objects - register_default_adapters(connection) + raw_cursor = conn.extra_dejson.get("cursor") - # Add the notice handler AFTER the connection is established - if self.enable_log_db_messages and hasattr(self.conn, "add_notice_handler"): - self.conn.add_notice_handler(self._notice_handler) - else: # psycopg2 - raw_cursor = conn.extra_dejson.get("cursor", False) - if raw_cursor: - conn_args["cursor_factory"] = self._get_cursor(raw_cursor) + if raw_cursor: + key, value = self._get_cursor_config(raw_cursor) + conn_args[key] = value - self.conn = cast("CompatConnection", ppg2_connect(**conn_args)) + self.conn = self._create_connection(conn_args) return self.conn diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py index 6cad536c07b78..fab8bc6f5d7d5 100644 --- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py @@ -19,7 +19,6 @@ import json import os -from types import SimpleNamespace from unittest import mock import pandas as pd @@ -30,7 +29,7 @@ from airflow.models import Connection from airflow.providers.common.compat.sdk import AirflowException, AirflowOptionalProviderFeatureException from airflow.providers.postgres.dialects.postgres import PostgresDialect -from airflow.providers.postgres.hooks.postgres import CompatConnection, PostgresHook +from airflow.providers.postgres.hooks.postgres import PostgresHook from tests_common.test_utils.common_sql import mock_db_hook from tests_common.test_utils.version_compat import NOTSET @@ -59,36 +58,6 @@ import psycopg2.extras -@pytest.fixture -def postgres_hook_setup(): - """Set up mock PostgresHook for testing.""" - table = "test_postgres_hook_table" - cur = mock.MagicMock(rowcount=0) - conn = mock.MagicMock(spec=CompatConnection) - conn.cursor.return_value = cur - - class UnitTestPostgresHook(PostgresHook): - conn_name_attr = "test_conn_id" - - def get_conn(self): - return conn - - db_hook = UnitTestPostgresHook() - - # Return a namespace with all the objects - setup = SimpleNamespace(table=table, cur=cur, conn=conn, db_hook=db_hook) - - yield setup - - # Teardown - only for real database tests - try: - with PostgresHook().get_conn() as real_conn: - with real_conn.cursor() as real_cur: - real_cur.execute(f"DROP TABLE IF EXISTS {table}") - except Exception: - pass # Ignore cleanup errors for unit tests - - @pytest.fixture def mock_connect(mocker): """Mock the connection object according to the correct psycopg version.""" @@ -816,10 +785,8 @@ def test_generate_insert_sql_with_already_escaped_column_name(self): ) == INSERT_SQL_STATEMENT.format('"schema"') -@pytest.mark.backend("postgres") -@pytest.mark.skipif(USE_PSYCOPG3, reason="psycopg v3 is available") -class TestPostgresHookPPG2: - """PostgresHook tests that are specific to psycopg2.""" +class _BasePostgresHookRuntimeTests: + """Shared runtime tests for psycopg2 and psycopg3.""" table = "test_postgres_hook_table" @@ -841,6 +808,121 @@ def teardown_method(self): with conn.cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table}") + def test_insert_rows(self): + table = "table" + rows = [("hello",), ("world",)] + + self.db_hook.insert_rows(table, rows) + + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 + assert self.conn.commit.call_count == 2 + + sql = f"INSERT INTO {table} VALUES (%s)" + self.cur.executemany.assert_any_call(sql, rows) + + def test_insert_rows_replace(self): + table = "table" + rows = [ + (1, "hello"), + (2, "world"), + ] + fields = ("id", "value") + + self.db_hook.insert_rows( + table, + rows, + fields, + replace=True, + replace_index=fields[0], + ) + + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 + assert self.conn.commit.call_count == 2 + + sql = ( + f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) " + f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = excluded.{fields[1]}" + ) + self.cur.executemany.assert_any_call(sql, rows) + + def test_insert_rows_replace_missing_target_field_arg(self): + table = "table" + rows = [ + (1, "hello"), + (2, "world"), + ] + fields = ("id", "value") + + with pytest.raises( + ValueError, + match="PostgreSQL ON CONFLICT upsert syntax requires column names", + ): + self.db_hook.insert_rows( + table, + rows, + replace=True, + replace_index=fields[0], + ) + + def test_insert_rows_replace_missing_replace_index_arg(self): + table = "table" + rows = [ + (1, "hello"), + (2, "world"), + ] + fields = ("id", "value") + + with pytest.raises( + ValueError, + match="PostgreSQL ON CONFLICT upsert syntax requires an unique index", + ): + self.db_hook.insert_rows( + table, + rows, + fields, + replace=True, + ) + + def test_insert_rows_replace_all_index(self): + table = "table" + rows = [ + (1, "hello"), + (2, "world"), + ] + fields = ("id", "value") + + self.db_hook.insert_rows( + table, + rows, + fields, + replace=True, + replace_index=fields, + ) + + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 + assert self.conn.commit.call_count == 2 + + sql = ( + f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) " + f"ON CONFLICT ({', '.join(fields)}) DO NOTHING" + ) + self.cur.executemany.assert_any_call(sql, rows) + + def test_dialect_name(self): + assert self.db_hook.dialect_name == "postgresql" + + def test_dialect(self): + assert isinstance(self.db_hook.dialect, PostgresDialect) + + +@pytest.mark.backend("postgres") +@pytest.mark.skipif(USE_PSYCOPG3, reason="psycopg v3 is available") +class TestPostgresHookPPG2(_BasePostgresHookRuntimeTests): + """PostgresHook tests that are specific to psycopg2.""" + def test_copy_expert(self, mocker): open_mock = mocker.mock_open(read_data='{"some": "json"}') mocker.patch("airflow.providers.postgres.hooks.postgres.open", open_mock) @@ -915,169 +997,59 @@ def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_ assert call_kw["sql"] == sql assert call_kw["sql_parameters"] == parameters - def test_insert_rows(self, postgres_hook_setup): - setup = postgres_hook_setup - table = "table" - rows = [("hello",), ("world",)] - - setup.db_hook.insert_rows(table, rows) - - assert setup.conn.close.call_count == 1 - assert setup.cur.close.call_count == 1 - - commit_count = 2 # The first and last commit - assert commit_count == setup.conn.commit.call_count - - sql = f"INSERT INTO {table} VALUES (%s)" - setup.cur.executemany.assert_any_call(sql, rows) - @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") - def test_insert_rows_hook_lineage(self, mock_send_lineage, postgres_hook_setup): - setup = postgres_hook_setup + def test_insert_rows_hook_lineage(self, mock_send_lineage): table = "table" rows = [("hello",), ("world",)] - setup.db_hook.insert_rows(table, rows) + self.db_hook.insert_rows(table, rows) mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs - assert call_kw["context"] is setup.db_hook + + assert call_kw["context"] is self.db_hook assert call_kw["sql"] == f"INSERT INTO {table} VALUES (%s)" assert call_kw["row_count"] == 2 @mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch") - def test_insert_rows_fast_executemany(self, mock_execute_batch, postgres_hook_setup): - setup = postgres_hook_setup + def test_insert_rows_fast_executemany(self, mock_execute_batch): table = "table" rows = [("hello",), ("world",)] - setup.db_hook.insert_rows(table, rows, fast_executemany=True) + self.db_hook.insert_rows(table, rows, fast_executemany=True) - assert setup.conn.close.call_count == 1 - assert setup.cur.close.call_count == 1 + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 commit_count = 2 # The first and last commit - assert setup.conn.commit.call_count == commit_count + assert self.conn.commit.call_count == commit_count mock_execute_batch.assert_called_once_with( - setup.cur, + self.cur, f"INSERT INTO {table} VALUES (%s)", # expected SQL [("hello",), ("world",)], # expected values page_size=1000, ) # executemany should NOT be called in this mode - setup.cur.executemany.assert_not_called() + self.cur.executemany.assert_not_called() @mock.patch("airflow.providers.postgres.hooks.postgres.send_sql_hook_lineage") @mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch") - def test_insert_rows_fast_executemany_hook_lineage( - self, mock_execute_batch, mock_send_lineage, postgres_hook_setup - ): - setup = postgres_hook_setup + def test_insert_rows_fast_executemany_hook_lineage(self, mock_execute_batch, mock_send_lineage): + table = "table" rows = [("hello",), ("world",)] - setup.db_hook.insert_rows(table, rows, fast_executemany=True) + self.db_hook.insert_rows(table, rows, fast_executemany=True) mock_send_lineage.assert_called_once() call_kw = mock_send_lineage.call_args.kwargs - assert call_kw["context"] is setup.db_hook + assert call_kw["context"] is self.db_hook assert call_kw["sql"] == f"INSERT INTO {table} VALUES (%s)" assert call_kw["row_count"] == 2 - def test_insert_rows_replace(self, postgres_hook_setup): - setup = postgres_hook_setup - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - - setup.db_hook.insert_rows(table, rows, fields, replace=True, replace_index=fields[0]) - - assert setup.conn.close.call_count == 1 - assert setup.cur.close.call_count == 1 - - commit_count = 2 # The first and last commit - assert commit_count == setup.conn.commit.call_count - - sql = ( - f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) " - f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = excluded.{fields[1]}" - ) - setup.cur.executemany.assert_any_call(sql, rows) - - def test_insert_rows_replace_missing_target_field_arg(self, postgres_hook_setup): - setup = postgres_hook_setup - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert syntax requires column names"): - setup.db_hook.insert_rows(table, rows, replace=True, replace_index=fields[0]) - - def test_insert_rows_replace_missing_replace_index_arg(self, postgres_hook_setup): - setup = postgres_hook_setup - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert syntax requires an unique index"): - setup.db_hook.insert_rows(table, rows, fields, replace=True) - - def test_insert_rows_replace_all_index(self, postgres_hook_setup): - setup = postgres_hook_setup - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - - setup.db_hook.insert_rows(table, rows, fields, replace=True, replace_index=fields) - - assert setup.conn.close.call_count == 1 - assert setup.cur.close.call_count == 1 - - commit_count = 2 # The first and last commit - assert commit_count == setup.conn.commit.call_count - - sql = ( - f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) " - f"ON CONFLICT ({', '.join(fields)}) DO NOTHING" - ) - setup.cur.executemany.assert_any_call(sql, rows) - @pytest.mark.usefixtures("reset_logging_config") def test_get_all_db_log_messages(self, mocker): messages = ["a", "b", "c"] @@ -1120,40 +1092,12 @@ def test_log_db_messages_by_db_proc(self, mocker): finally: hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)") - def test_dialect_name(self, postgres_hook_setup): - setup = postgres_hook_setup - assert setup.db_hook.dialect_name == "postgresql" - - def test_dialect(self, postgres_hook_setup): - setup = postgres_hook_setup - assert isinstance(setup.db_hook.dialect, PostgresDialect) - @pytest.mark.backend("postgres") @pytest.mark.skipif(not USE_PSYCOPG3, reason="psycopg v3 or sqlalchemy v2 are not available") -class TestPostgresHookPPG3: +class TestPostgresHookPPG3(_BasePostgresHookRuntimeTests): """PostgresHook tests that are specific to psycopg3.""" - table = "test_postgres_hook_table" - - def setup_method(self): - self.cur = mock.MagicMock(rowcount=0) - self.conn = conn = mock.MagicMock() - self.conn.cursor.return_value = self.cur - - class UnitTestPostgresHook(PostgresHook): - conn_name_attr = "test_conn_id" - - def get_conn(self): - return conn - - self.db_hook = UnitTestPostgresHook() - - def teardown_method(self): - with PostgresHook().get_conn() as conn: - with conn.cursor() as cur: - cur.execute(f"DROP TABLE IF EXISTS {self.table}") - def test_copy_expert_from(self, mocker): """Tests copy_expert with a 'COPY FROM STDIN' operation.""" statement = "COPY test_table FROM STDIN" @@ -1235,109 +1179,6 @@ def test_copy_expert_to(self, mocker): ) self.conn.commit.assert_called_once() - def test_insert_rows(self): - table = "table" - rows = [("hello",), ("world",)] - - self.db_hook.insert_rows(table, rows) - - assert self.conn.close.call_count == 1 - assert self.cur.close.call_count == 1 - - commit_count = 2 # The first and last commit - assert commit_count == self.conn.commit.call_count - - sql = f"INSERT INTO {table} VALUES (%s)" - self.cur.executemany.assert_any_call(sql, rows) - - def test_insert_rows_replace(self): - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - - self.db_hook.insert_rows(table, rows, fields, replace=True, replace_index=fields[0]) - - assert self.conn.close.call_count == 1 - assert self.cur.close.call_count == 1 - - commit_count = 2 # The first and last commit - assert commit_count == self.conn.commit.call_count - - sql = ( - f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) " - f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = excluded.{fields[1]}" - ) - self.cur.executemany.assert_any_call(sql, rows) - - def test_insert_rows_replace_missing_target_field_arg(self): - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert syntax requires column names"): - self.db_hook.insert_rows(table, rows, replace=True, replace_index=fields[0]) - - def test_insert_rows_replace_missing_replace_index_arg(self): - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert syntax requires an unique index"): - self.db_hook.insert_rows(table, rows, fields, replace=True) - - def test_insert_rows_replace_all_index(self): - table = "table" - rows = [ - ( - 1, - "hello", - ), - ( - 2, - "world", - ), - ] - fields = ("id", "value") - - self.db_hook.insert_rows(table, rows, fields, replace=True, replace_index=fields) - - assert self.conn.close.call_count == 1 - assert self.cur.close.call_count == 1 - - commit_count = 2 # The first and last commit - assert commit_count == self.conn.commit.call_count - - sql = ( - f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) " - f"ON CONFLICT ({', '.join(fields)}) DO NOTHING" - ) - self.cur.executemany.assert_any_call(sql, rows) - @pytest.mark.skip(reason="Notice handling is callback-based in psycopg3 and cannot be tested this way.") def test_get_all_db_log_messages(self, mocker): pass @@ -1366,9 +1207,3 @@ def test_log_db_messages_by_db_proc(self, mocker): mock_logger.info.assert_any_call("Message from db: 42") finally: hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)") - - def test_dialect_name(self): - assert self.db_hook.dialect_name == "postgresql" - - def test_dialect(self): - assert isinstance(self.db_hook.dialect, PostgresDialect)