diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 7f1536a39b47e..4625c2e014f7a 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -18,7 +18,7 @@ import contextlib import warnings -from contextlib import closing +from contextlib import closing, contextmanager from datetime import datetime from typing import ( TYPE_CHECKING, @@ -147,6 +147,8 @@ class DbApiHook(BaseHook): default_conn_name = "default_conn_id" # Override if this db supports autocommit. supports_autocommit = False + # Override if this db supports executemany. + supports_executemany = False # Override with the object that exposes the connect method connector: ConnectorProtocol | None = None # Override with db-specific query to check connection @@ -408,10 +410,7 @@ def run( else: raise ValueError("List of SQL statements is empty") _last_result = None - with closing(self.get_conn()) as conn: - if self.supports_autocommit: - self.set_autocommit(conn, autocommit) - + with self._create_autocommit_connection(autocommit) as conn: with closing(conn.cursor()) as cur: results = [] for sql_statement in sql_list: @@ -528,6 +527,14 @@ def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) return self._replace_statement_format.format(table, target_fields, ",".join(placeholders)) + @contextmanager + def _create_autocommit_connection(self, autocommit: bool = False): + """Context manager that closes the connection after use and detects if autocommit is supported.""" + with closing(self.get_conn()) as conn: + if self.supports_autocommit: + self.set_autocommit(conn, autocommit) + yield conn + def insert_rows( self, table, @@ -550,47 +557,48 @@ def insert_rows( :param commit_every: The maximum number of rows to insert in one transaction. Set to 0 to insert all rows in one transaction. :param replace: Whether to replace instead of insert - :param executemany: Insert all rows at once in chunks defined by the commit_every parameter, only - works if all rows have same number of column names but leads to better performance + :param executemany: (Deprecated) If True, all rows are inserted at once in + chunks defined by the commit_every parameter. This only works if all rows + have same number of column names, but leads to better performance. """ - i = 0 - with closing(self.get_conn()) as conn: - if self.supports_autocommit: - self.set_autocommit(conn, False) + if executemany: + warnings.warn( + "executemany parameter is deprecated, override supports_executemany instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + with self._create_autocommit_connection() as conn: conn.commit() - with closing(conn.cursor()) as cur: - if executemany: + if self.supports_executemany or executemany: for chunked_rows in chunked(rows, commit_every): values = list( map( - lambda row: tuple(map(lambda cell: self._serialize_cell(cell, conn), row)), + lambda row: self._serialize_cells(row, conn), chunked_rows, ) ) sql = self._generate_insert_sql(table, values[0], target_fields, replace, **kwargs) self.log.debug("Generated sql: %s", sql) - cur.fast_executemany = True cur.executemany(sql, values) conn.commit() self.log.info("Loaded %s rows into %s so far", len(chunked_rows), table) else: for i, row in enumerate(rows, 1): - lst = [] - for cell in row: - lst.append(self._serialize_cell(cell, conn)) - values = tuple(lst) + values = self._serialize_cells(row, conn) sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs) self.log.debug("Generated sql: %s", sql) cur.execute(sql, values) if commit_every and i % commit_every == 0: conn.commit() self.log.info("Loaded %s rows into %s so far", i, table) + conn.commit() + self.log.info("Done loading. Loaded a total of %s rows into %s", len(rows), table) - if not executemany: - conn.commit() - self.log.info("Done loading. Loaded a total of %s rows into %s", i, table) + @classmethod + def _serialize_cells(cls, row, conn=None): + return tuple(cls._serialize_cell(cell, conn) for cell in row) @staticmethod def _serialize_cell(cell, conn=None) -> str | None: diff --git a/airflow/providers/common/sql/hooks/sql.pyi b/airflow/providers/common/sql/hooks/sql.pyi index 83135a235bf97..85edd625f99e0 100644 --- a/airflow/providers/common/sql/hooks/sql.pyi +++ b/airflow/providers/common/sql/hooks/sql.pyi @@ -57,6 +57,7 @@ class DbApiHook(BaseHook): conn_name_attr: str default_conn_name: str supports_autocommit: bool + supports_executemany: bool connector: ConnectorProtocol | None log_sql: Incomplete descriptions: Incomplete diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py index 8cf95bf095f31..53c4cf207a9e7 100644 --- a/airflow/providers/odbc/hooks/odbc.py +++ b/airflow/providers/odbc/hooks/odbc.py @@ -56,6 +56,7 @@ class OdbcHook(DbApiHook): conn_type = "odbc" hook_name = "ODBC" supports_autocommit = True + supports_executemany = True default_driver: str | None = None diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 0afb7740fe58f..9e1b3a83d7282 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -74,6 +74,7 @@ class PostgresHook(DbApiHook): conn_type = "postgres" hook_name = "Postgres" supports_autocommit = True + supports_executemany = True def __init__(self, *args, options: str | None = None, **kwargs) -> None: if "schema" in kwargs: diff --git a/airflow/providers/teradata/hooks/teradata.py b/airflow/providers/teradata/hooks/teradata.py index 73c4fb8ff0523..3afc32bc74746 100644 --- a/airflow/providers/teradata/hooks/teradata.py +++ b/airflow/providers/teradata/hooks/teradata.py @@ -19,12 +19,14 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any import sqlalchemy import teradatasql from teradatasql import TeradataConnection +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.common.sql.hooks.sql import DbApiHook if TYPE_CHECKING: @@ -59,6 +61,9 @@ class TeradataHook(DbApiHook): # Override if this db supports autocommit. supports_autocommit = True + # Override if this db supports executemany. + supports_executemany = True + # Override this for hook to have a custom name in the UI selection conn_type = "teradata" @@ -97,7 +102,9 @@ def bulk_insert_rows( target_fields: list[str] | None = None, commit_every: int = 5000, ): - """Insert bulk of records into Teradata SQL Database. + """Use :func:`insert_rows` instead, this is deprecated. + + Insert bulk of records into Teradata SQL Database. This uses prepared statements via `executemany()`. For best performance, pass in `rows` as an iterator. @@ -106,41 +113,20 @@ def bulk_insert_rows( specific database :param rows: the rows to insert into the table :param target_fields: the names of the columns to fill in the table, default None. - If None, each rows should have some order as table columns name + If None, each row should have some order as table columns name :param commit_every: the maximum number of rows to insert in one transaction Default 5000. Set greater than 0. Set 1 to insert each row in each transaction """ + warnings.warn( + "bulk_insert_rows is deprecated. Please use the insert_rows method instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + if not rows: raise ValueError("parameter rows could not be None or empty iterable") - conn = self.get_conn() - if self.supports_autocommit: - self.set_autocommit(conn, False) - cursor = conn.cursor() - cursor.fast_executemany = True - values_base = target_fields if target_fields else rows[0] - prepared_stm = "INSERT INTO {tablename} {columns} VALUES ({values})".format( - tablename=table, - columns="({})".format(", ".join(target_fields)) if target_fields else "", - values=", ".join("?" for i in range(1, len(values_base) + 1)), - ) - row_count = 0 - # Chunk the rows - row_chunk = [] - for row in rows: - row_chunk.append(row) - row_count += 1 - if row_count % commit_every == 0: - cursor.executemany(prepared_stm, row_chunk) - conn.commit() # type: ignore[attr-defined] - # Empty chunk - row_chunk = [] - # Commit the leftover chunk - if len(row_chunk) > 0: - cursor.executemany(prepared_stm, row_chunk) - conn.commit() # type: ignore[attr-defined] - self.log.info("[%s] inserted %s rows", table, row_count) - cursor.close() - conn.close() # type: ignore[attr-defined] + + self.insert_rows(table=table, rows=rows, target_fields=target_fields, commit_every=commit_every) def _get_conn_config_teradatasql(self) -> dict[str, Any]: """Return set of config params required for connecting to Teradata DB using teradatasql client.""" diff --git a/tests/deprecations_ignore.yml b/tests/deprecations_ignore.yml index c88d2b9368cb2..c8f02a193d6d0 100644 --- a/tests/deprecations_ignore.yml +++ b/tests/deprecations_ignore.yml @@ -659,6 +659,8 @@ - tests/providers/cncf/kubernetes/test_pod_generator.py::TestPodGenerator::test_pod_name_confirm_to_max_length - tests/providers/cncf/kubernetes/test_pod_generator.py::TestPodGenerator::test_pod_name_is_valid - tests/providers/cncf/kubernetes/test_template_rendering.py::test_render_k8s_pod_yaml +- tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_insert_rows_executemany +- tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_insert_rows_replace_executemany_hana_dialect - tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_instance_check_works_for_legacy_db_api_hook - tests/providers/common/sql/operators/test_sql.py::TestSQLCheckOperatorDbHook::test_get_hook - tests/providers/common/sql/operators/test_sql.py::TestSqlBranch::test_branch_false_with_dag_run @@ -1059,6 +1061,10 @@ - tests/providers/ssh/hooks/test_ssh.py::TestSSHHook::test_tunnel_without_password - tests/providers/tableau/hooks/test_tableau.py::TestTableauHook::test_get_conn_auth_via_token_and_site_in_init - tests/providers/tableau/hooks/test_tableau.py::TestTableauHook::test_get_conn_ssl_default +- tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_with_fields +- tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_with_commit_every +- tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_without_fields +- tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_no_rows - tests/providers/trino/operators/test_trino.py::test_execute_openlineage_events - tests/providers/vertica/operators/test_vertica.py::TestVerticaOperator::test_execute - tests/providers/weaviate/operators/test_weaviate.py::TestWeaviateIngestOperator::test_constructor diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py index fd9886345fd6b..2c34ee133ec71 100644 --- a/tests/providers/common/sql/hooks/test_dbapi.py +++ b/tests/providers/common/sql/hooks/test_dbapi.py @@ -21,6 +21,7 @@ from unittest import mock import pytest +from pyodbc import Cursor from airflow.hooks.base import BaseHook from airflow.models import Connection @@ -39,7 +40,7 @@ class TestDbApiHook: def setup_method(self, **kwargs): self.cur = mock.MagicMock( rowcount=0, - spec=["description", "rowcount", "execute", "executemany", "fetchall", "fetchone", "close"], + spec=Cursor, ) self.conn = mock.MagicMock() self.conn.cursor.return_value = self.cur diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index 2d62cb4f43129..8330ad3b1d53f 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -403,8 +403,7 @@ def test_insert_rows(self): assert commit_count == self.conn.commit.call_count sql = f"INSERT INTO {table} VALUES (%s)" - for row in rows: - self.cur.execute.assert_any_call(sql, row) + self.cur.executemany.assert_any_call(sql, rows) def test_insert_rows_replace(self): table = "table" @@ -432,8 +431,7 @@ def test_insert_rows_replace(self): f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) " f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = excluded.{fields[1]}" ) - for row in rows: - self.cur.execute.assert_any_call(sql, row) + self.cur.executemany.assert_any_call(sql, rows) def test_insert_rows_replace_missing_target_field_arg(self): table = "table" @@ -497,8 +495,7 @@ def test_insert_rows_replace_all_index(self): f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) " f"ON CONFLICT ({', '.join(fields)}) DO NOTHING" ) - for row in rows: - self.cur.execute.assert_any_call(sql, row) + self.cur.executemany.assert_any_call(sql, rows) def test_rowcount(self): hook = PostgresHook() diff --git a/tests/providers/teradata/hooks/test_teradata.py b/tests/providers/teradata/hooks/test_teradata.py index d47e987f41f22..af77d66a636fe 100644 --- a/tests/providers/teradata/hooks/test_teradata.py +++ b/tests/providers/teradata/hooks/test_teradata.py @@ -226,9 +226,9 @@ def test_insert_rows(self): "str", ] self.test_db_hook.insert_rows("table", rows, target_fields) - self.cur.execute.assert_called_once_with( + self.cur.executemany.assert_called_once_with( "INSERT INTO table (basestring, none, datetime, int, float, str) VALUES (?,?,?,?,?,?)", - ("'test_string", None, "2023-08-15T00:00:00", "1", "3.14", "str"), + [("'test_string", None, "2023-08-15T00:00:00", "1", "3.14", "str")], ) def test_bulk_insert_rows_with_fields(self): @@ -236,7 +236,8 @@ def test_bulk_insert_rows_with_fields(self): target_fields = ["col1", "col2", "col3"] self.test_db_hook.bulk_insert_rows("table", rows, target_fields) self.cur.executemany.assert_called_once_with( - "INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", rows + "INSERT INTO table (col1, col2, col3) VALUES (?,?,?)", + [("1", "2", "3"), ("4", "5", "6"), ("7", "8", "9")], ) def test_bulk_insert_rows_with_commit_every(self): @@ -244,19 +245,20 @@ def test_bulk_insert_rows_with_commit_every(self): target_fields = ["col1", "col2", "col3"] self.test_db_hook.bulk_insert_rows("table", rows, target_fields, commit_every=2) calls = [ - mock.call("INSERT INTO table (col1, col2, col3) values (1, 2, 3)"), - mock.call("INSERT INTO table (col1, col2, col3) values (1, 2, 3)"), - ] - calls = [ - mock.call("INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", rows[:2]), - mock.call("INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", rows[2:]), + mock.call( + "INSERT INTO table (col1, col2, col3) VALUES (?,?,?)", [("1", "2", "3"), ("4", "5", "6")] + ), + mock.call("INSERT INTO table (col1, col2, col3) VALUES (?,?,?)", [("7", "8", "9")]), ] self.cur.executemany.assert_has_calls(calls, any_order=True) def test_bulk_insert_rows_without_fields(self): rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)] self.test_db_hook.bulk_insert_rows("table", rows) - self.cur.executemany.assert_called_once_with("INSERT INTO table VALUES (?, ?, ?)", rows) + self.cur.executemany.assert_called_once_with( + "INSERT INTO table VALUES (?,?,?)", + [("1", "2", "3"), ("4", "5", "6"), ("7", "8", "9")], + ) def test_bulk_insert_rows_no_rows(self): rows = []