Always use the executemany method when inserting rows in DbApiHook as it's way much faster #38715

Merged
69 commits
0312a8f
refactor: Always use the fast_executemany method when inserting rows …
Apr 3, 2024
219bd43
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 3, 2024
64b7a85
refactor: Only set fast_executemany option if explicitly defined in c…
Apr 3, 2024
c183404
refactor: Fixed tests related to insert_rows for Postgres
Apr 4, 2024
bc3895e
refactor: Fixed assertions on executemany for insert rows in Postgres…
Apr 4, 2024
3af4c7c
refactor: Deprecated bulk_insert_rows in Teradata as the insert_rows …
Apr 4, 2024
ef36ed0
refactor: Removed duplicate calls object definition in test_bulk_inse…
Apr 4, 2024
e222af5
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 4, 2024
d0fd763
refactor: Use DeprecationWarning instead
Apr 4, 2024
4dbdff5
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 4, 2024
42b7ab7
refactor: Use AirflowProviderDeprecationWarning instead
Apr 4, 2024
58d3cfa
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 4, 2024
04d0a47
refactor: Ignore deprecation warnings for TestTeradataHook
Apr 4, 2024
6b0ce23
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 4, 2024
1cb5e5e
refactor: Re-added check on rows parameter for bulk_insert_rows metho…
Apr 4, 2024
57318a9
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 4, 2024
4831868
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 4, 2024
6bb724d
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 4, 2024
35580fe
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 4, 2024
3a9fcbf
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 4, 2024
1e577c7
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 6, 2024
96f5019
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 8, 2024
b64351f
refactor: Fixed logging statement regarding total rows in insert_rows…
Apr 8, 2024
92775af
refactor: Changed string values back to int's but adapted expected du…
Apr 8, 2024
1369944
refactor: Added deprecation warning for executemany parameter in inse…
Apr 8, 2024
ac185d2
refactor: Reformatted mocked call for test_bulk_insert_rows_with_comm…
Apr 8, 2024
7216c3e
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 8, 2024
f8d8d92
refactor: Added explicit fast_executemany parameter to constructor of…
Apr 8, 2024
fee8e01
refactor: Reformatted second mocked call in test_bulk_insert_rows_wit…
Apr 8, 2024
39fe36b
refactor: Re-added executemany param to insert_rows method
Apr 8, 2024
7b65015
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 8, 2024
5634360
refactor: Put executemany between single quote to avoid spelling chec…
Apr 8, 2024
cdfeaf2
Revert "refactor: Added deprecation warning for executemany parameter…
Apr 8, 2024
934239a
refactor: Still support original implementation without executemany i…
Apr 8, 2024
b3ad6a6
refactor: Reformatted insert_rows method in DbApiHook
Apr 8, 2024
49e923c
fix: Fixed insert_rows method in DbApiHook and extracted common _seri…
Apr 8, 2024
4e6df4e
fix: Re-added check on fast_executemany in insert_rows method
Apr 8, 2024
a216fe9
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 8, 2024
82ca138
refactor: Put docstring of _closing_supporting_autocommit method on o…
Apr 8, 2024
be17da1
refactor: Added autocommit parameter which is False by default to _cl…
Apr 8, 2024
8e8b692
refactor: Added supports_executemany class variable which allows to s…
Apr 8, 2024
4d6d240
refactor: test_insert_rows should use executemany for Postgres
Apr 8, 2024
564b5ef
refactor: Deprecation warning for test_insert_rows_replace_executeman…
Apr 8, 2024
e6dd1a5
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 8, 2024
215d719
refactor: Deprecation warning for test_insert_rows_executemany should…
Apr 8, 2024
45b3a94
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 8, 2024
161f3d2
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 9, 2024
8566316
refactor: Removed fast_executemany
Apr 9, 2024
74e0659
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 9, 2024
ec16892
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 9, 2024
65d2ae7
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 9, 2024
4590409
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 9, 2024
f6fc1be
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 9, 2024
4b9e47b
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 10, 2024
d2b0629
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 10, 2024
5580451
refactor: Removed unused fast_executemany param from DbApiHook constr…
Apr 11, 2024
b242539
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 11, 2024
0250d49
refactor: Removed placeholder attribute from DbApiHook interface
Apr 11, 2024
e17aca5
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 11, 2024
15f76dc
Revert unnecessary format change
uranusjr Apr 11, 2024
7814ee3
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 11, 2024
65f0518
docs: Updated docstring for excutemany parameter in insert_rows method
dabla Apr 11, 2024
8850f82
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 11, 2024
3b029e1
refactor: Renamed _closing_supporting_autocommit method to _create_au…
Apr 11, 2024
85536d9
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 11, 2024
b9ccfc1
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 11, 2024
1d295c0
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 11, 2024
ad65b9b
Use comprehension instead of map
uranusjr Apr 11, 2024
1c58f99
Merge branch 'main' into feature/sql-performance-enhancement-insertmany
dabla Apr 11, 2024
52 changes: 30 additions & 22 deletions airflow/providers/common/sql/hooks/sql.py
@@ -18,7 +18,7 @@

import contextlib
import warnings
from contextlib import closing
from contextlib import closing, contextmanager
from datetime import datetime
from typing import (
TYPE_CHECKING,
@@ -147,6 +147,8 @@ class DbApiHook(BaseHook):
default_conn_name = "default_conn_id"
# Override if this db supports autocommit.
supports_autocommit = False
# Override if this db supports executemany.
supports_executemany = False
# Override with the object that exposes the connect method
connector: ConnectorProtocol | None = None
# Override with db-specific query to check connection
@@ -408,10 +410,7 @@ def run(
else:
raise ValueError("List of SQL statements is empty")
_last_result = None
with closing(self.get_conn()) as conn:
if self.supports_autocommit:
self.set_autocommit(conn, autocommit)

with self._create_autocommit_connection(autocommit) as conn:
with closing(conn.cursor()) as cur:
results = []
for sql_statement in sql_list:
@@ -528,6 +527,14 @@ def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs)

return self._replace_statement_format.format(table, target_fields, ",".join(placeholders))

@contextmanager
def _create_autocommit_connection(self, autocommit: bool = False):
"""Context manager that closes the connection after use and detects if autocommit is supported."""
with closing(self.get_conn()) as conn:
if self.supports_autocommit:
self.set_autocommit(conn, autocommit)
yield conn
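The new `_create_autocommit_connection` helper can be exercised in isolation. A minimal sketch, stdlib only — `MiniHook` and `FakeConn` are hypothetical stand-ins for `DbApiHook` and a DB-API connection, not part of this PR:

```python
from contextlib import closing, contextmanager


class FakeConn:
    """Hypothetical stand-in for a DB-API connection."""

    def __init__(self):
        self.autocommit = None
        self.closed = False

    def close(self):
        self.closed = True


class MiniHook:
    # Subclasses that support autocommit override this, mirroring DbApiHook.
    supports_autocommit = False

    def get_conn(self):
        return FakeConn()

    def set_autocommit(self, conn, autocommit):
        conn.autocommit = autocommit

    @contextmanager
    def _create_autocommit_connection(self, autocommit: bool = False):
        # closing() guarantees conn.close() runs even if the body raises;
        # autocommit is only touched when the hook declares support for it.
        with closing(self.get_conn()) as conn:
            if self.supports_autocommit:
                self.set_autocommit(conn, autocommit)
            yield conn
```

The context manager lets `run` and `insert_rows` share one connection-lifecycle path instead of each repeating the `closing`/`set_autocommit` dance.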

def insert_rows(
self,
table,
@@ -550,47 +557,48 @@ def insert_rows(
:param commit_every: The maximum number of rows to insert in one
transaction. Set to 0 to insert all rows in one transaction.
:param replace: Whether to replace instead of insert
:param executemany: Insert all rows at once in chunks defined by the commit_every parameter, only
works if all rows have same number of column names but leads to better performance
:param executemany: (Deprecated) If True, all rows are inserted at once in
chunks defined by the commit_every parameter. This only works if all rows
have same number of column names, but leads to better performance.
"""
i = 0
with closing(self.get_conn()) as conn:
if self.supports_autocommit:
self.set_autocommit(conn, False)
if executemany:
warnings.warn(
"executemany parameter is deprecated, override supports_executemany instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

with self._create_autocommit_connection() as conn:
conn.commit()

with closing(conn.cursor()) as cur:
if executemany:
if self.supports_executemany or executemany:
for chunked_rows in chunked(rows, commit_every):
values = list(
map(
lambda row: tuple(map(lambda cell: self._serialize_cell(cell, conn), row)),
lambda row: self._serialize_cells(row, conn),
chunked_rows,
)
)
sql = self._generate_insert_sql(table, values[0], target_fields, replace, **kwargs)
self.log.debug("Generated sql: %s", sql)
cur.fast_executemany = True
cur.executemany(sql, values)
conn.commit()
self.log.info("Loaded %s rows into %s so far", len(chunked_rows), table)
else:
for i, row in enumerate(rows, 1):
lst = []
for cell in row:
lst.append(self._serialize_cell(cell, conn))
values = tuple(lst)
values = self._serialize_cells(row, conn)
sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs)
self.log.debug("Generated sql: %s", sql)
cur.execute(sql, values)
if commit_every and i % commit_every == 0:
conn.commit()
self.log.info("Loaded %s rows into %s so far", i, table)
conn.commit()
self.log.info("Done loading. Loaded a total of %s rows into %s", len(rows), table)

if not executemany:
conn.commit()
self.log.info("Done loading. Loaded a total of %s rows into %s", i, table)
@classmethod
def _serialize_cells(cls, row, conn=None):
return tuple(cls._serialize_cell(cell, conn) for cell in row)

@staticmethod
def _serialize_cell(cell, conn=None) -> str | None:
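The `executemany` branch of `insert_rows` chunks the rows and issues one `executemany()` per chunk, committing between chunks. A sketch of that pattern against sqlite3 — `chunked` here is a hypothetical stand-in for the helper Airflow imports, and the table and SQL shapes are illustrative, not the hook's actual `_generate_insert_sql` output:

```python
import sqlite3
from itertools import islice


def chunked(iterable, size):
    """Yield successive lists of at most `size` items."""
    it = iter(iterable)
    while chunk := list(islice(it, size)):
        yield chunk


def insert_rows_executemany(conn, table, rows, commit_every=1000):
    # One INSERT per chunk via executemany(), committing between chunks,
    # mirroring the branch taken when supports_executemany is True.
    cur = conn.cursor()
    for chunk in chunked(rows, commit_every):
        placeholders = ",".join("?" for _ in chunk[0])
        sql = f"INSERT INTO {table} VALUES ({placeholders})"
        cur.executemany(sql, chunk)
        conn.commit()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (a INTEGER, b TEXT)")
insert_rows_executemany(conn, "t", [(1, "x"), (2, "y"), (3, "z")], commit_every=2)
```

This is why the PR title claims a speedup: one round trip per chunk instead of one `execute()` per row, at the cost of requiring every row to have the same shape.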
1 change: 1 addition & 0 deletions airflow/providers/common/sql/hooks/sql.pyi
@@ -57,6 +57,7 @@ class DbApiHook(BaseHook):
conn_name_attr: str
default_conn_name: str
supports_autocommit: bool
supports_executemany: bool
connector: ConnectorProtocol | None
log_sql: Incomplete
descriptions: Incomplete
1 change: 1 addition & 0 deletions airflow/providers/odbc/hooks/odbc.py
@@ -56,6 +56,7 @@ class OdbcHook(DbApiHook):
conn_type = "odbc"
hook_name = "ODBC"
supports_autocommit = True
supports_executemany = True

default_driver: str | None = None

1 change: 1 addition & 0 deletions airflow/providers/postgres/hooks/postgres.py
@@ -74,6 +74,7 @@ class PostgresHook(DbApiHook):
conn_type = "postgres"
hook_name = "Postgres"
supports_autocommit = True
supports_executemany = True

def __init__(self, *args, options: str | None = None, **kwargs) -> None:
if "schema" in kwargs:
48 changes: 17 additions & 31 deletions airflow/providers/teradata/hooks/teradata.py
@@ -19,12 +19,14 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any

import sqlalchemy
import teradatasql
from teradatasql import TeradataConnection

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
@@ -59,6 +61,9 @@ class TeradataHook(DbApiHook):
# Override if this db supports autocommit.
supports_autocommit = True

# Override if this db supports executemany.
supports_executemany = True

# Override this for hook to have a custom name in the UI selection
conn_type = "teradata"

@@ -97,7 +102,9 @@ def bulk_insert_rows(
target_fields: list[str] | None = None,
commit_every: int = 5000,
):
"""Insert bulk of records into Teradata SQL Database.
"""Use :func:`insert_rows` instead, this is deprecated.

Insert bulk of records into Teradata SQL Database.

This uses prepared statements via `executemany()`. For best performance,
pass in `rows` as an iterator.
@@ -106,41 +113,20 @@
specific database
:param rows: the rows to insert into the table
:param target_fields: the names of the columns to fill in the table, default None.
If None, each rows should have some order as table columns name
If None, each row should have some order as table columns name
:param commit_every: the maximum number of rows to insert in one transaction
Default 5000. Set greater than 0. Set 1 to insert each row in each transaction
"""
warnings.warn(
"bulk_insert_rows is deprecated. Please use the insert_rows method instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

if not rows:
raise ValueError("parameter rows could not be None or empty iterable")
conn = self.get_conn()
if self.supports_autocommit:
self.set_autocommit(conn, False)
cursor = conn.cursor()
cursor.fast_executemany = True
values_base = target_fields if target_fields else rows[0]
prepared_stm = "INSERT INTO {tablename} {columns} VALUES ({values})".format(
tablename=table,
columns="({})".format(", ".join(target_fields)) if target_fields else "",
values=", ".join("?" for i in range(1, len(values_base) + 1)),
)
row_count = 0
# Chunk the rows
row_chunk = []
for row in rows:
row_chunk.append(row)
row_count += 1
if row_count % commit_every == 0:
cursor.executemany(prepared_stm, row_chunk)
conn.commit() # type: ignore[attr-defined]
# Empty chunk
row_chunk = []
# Commit the leftover chunk
if len(row_chunk) > 0:
cursor.executemany(prepared_stm, row_chunk)
conn.commit() # type: ignore[attr-defined]
self.log.info("[%s] inserted %s rows", table, row_count)
cursor.close()
conn.close() # type: ignore[attr-defined]

self.insert_rows(table=table, rows=rows, target_fields=target_fields, commit_every=commit_every)

def _get_conn_config_teradatasql(self) -> dict[str, Any]:
"""Return set of config params required for connecting to Teradata DB using teradatasql client."""
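The rewritten `bulk_insert_rows` reduces to a warn-and-delegate shim. A minimal sketch of that pattern, using the stdlib `DeprecationWarning` in place of `AirflowProviderDeprecationWarning`; the hook class and its `insert_rows` stub are hypothetical:

```python
import warnings


class DeprecationShimHook:
    """Sketch of the warn-and-delegate pattern used for bulk_insert_rows."""

    def insert_rows(self, table, rows, target_fields=None, commit_every=1000):
        # Stand-in for the real implementation; just reports what it received.
        return ("insert_rows", table, len(list(rows)))

    def bulk_insert_rows(self, table, rows, target_fields=None, commit_every=5000):
        warnings.warn(
            "bulk_insert_rows is deprecated. Please use the insert_rows method instead.",
            DeprecationWarning,  # Airflow uses AirflowProviderDeprecationWarning here
            stacklevel=2,
        )
        if not rows:
            raise ValueError("parameter rows could not be None or empty iterable")
        return self.insert_rows(
            table=table, rows=rows, target_fields=target_fields, commit_every=commit_every
        )
```

Keeping the `rows` check in the shim preserves the old method's error behavior for empty input while all real work moves to `insert_rows`.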
6 changes: 6 additions & 0 deletions tests/deprecations_ignore.yml
@@ -659,6 +659,8 @@
- tests/providers/cncf/kubernetes/test_pod_generator.py::TestPodGenerator::test_pod_name_confirm_to_max_length
- tests/providers/cncf/kubernetes/test_pod_generator.py::TestPodGenerator::test_pod_name_is_valid
- tests/providers/cncf/kubernetes/test_template_rendering.py::test_render_k8s_pod_yaml
- tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_insert_rows_executemany
- tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_insert_rows_replace_executemany_hana_dialect
- tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_instance_check_works_for_legacy_db_api_hook
- tests/providers/common/sql/operators/test_sql.py::TestSQLCheckOperatorDbHook::test_get_hook
- tests/providers/common/sql/operators/test_sql.py::TestSqlBranch::test_branch_false_with_dag_run
@@ -1059,6 +1061,10 @@
- tests/providers/ssh/hooks/test_ssh.py::TestSSHHook::test_tunnel_without_password
- tests/providers/tableau/hooks/test_tableau.py::TestTableauHook::test_get_conn_auth_via_token_and_site_in_init
- tests/providers/tableau/hooks/test_tableau.py::TestTableauHook::test_get_conn_ssl_default
- tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_with_fields
- tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_with_commit_every
- tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_without_fields
- tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_no_rows
- tests/providers/trino/operators/test_trino.py::test_execute_openlineage_events
- tests/providers/vertica/operators/test_vertica.py::TestVerticaOperator::test_execute
- tests/providers/weaviate/operators/test_weaviate.py::TestWeaviateIngestOperator::test_constructor
3 changes: 2 additions & 1 deletion tests/providers/common/sql/hooks/test_dbapi.py
@@ -21,6 +21,7 @@
from unittest import mock

import pytest
from pyodbc import Cursor

from airflow.hooks.base import BaseHook
from airflow.models import Connection
@@ -39,7 +40,7 @@ class TestDbApiHook:
def setup_method(self, **kwargs):
self.cur = mock.MagicMock(
rowcount=0,
spec=["description", "rowcount", "execute", "executemany", "fetchall", "fetchone", "close"],
spec=Cursor,
)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
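The test change swaps a hand-written spec list for `spec=pyodbc.Cursor`, so the mock also exposes pyodbc-specific attributes such as `fast_executemany`. A stdlib-only sketch of the list-spec behavior being replaced — the SQL and attribute names are illustrative:

```python
from unittest import mock

# A list spec confines attribute *access* on the mock: reading any name not in
# the list raises AttributeError, so tests fail fast on typos and unknown APIs.
cur = mock.MagicMock(
    rowcount=0,
    spec=["description", "rowcount", "execute", "executemany", "fetchall", "fetchone", "close"],
)
cur.executemany("INSERT INTO t VALUES (?)", [(1,), (2,)])
cur.executemany.assert_called_once_with("INSERT INTO t VALUES (?)", [(1,), (2,)])

try:
    cur.fetchmany  # not in the spec list above
    raised = False
except AttributeError:
    raised = True
```

Using `spec=Cursor` trades this hand-maintained list for the real cursor surface, so the mock stays in sync with whatever attributes the hook touches.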
9 changes: 3 additions & 6 deletions tests/providers/postgres/hooks/test_postgres.py
@@ -403,8 +403,7 @@ def test_insert_rows(self):
assert commit_count == self.conn.commit.call_count

sql = f"INSERT INTO {table} VALUES (%s)"
for row in rows:
self.cur.execute.assert_any_call(sql, row)
self.cur.executemany.assert_any_call(sql, rows)

def test_insert_rows_replace(self):
table = "table"
@@ -432,8 +431,7 @@ def test_insert_rows_replace(self):
f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) "
f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = excluded.{fields[1]}"
)
for row in rows:
self.cur.execute.assert_any_call(sql, row)
self.cur.executemany.assert_any_call(sql, rows)

def test_insert_rows_replace_missing_target_field_arg(self):
table = "table"
@@ -497,8 +495,7 @@ def test_insert_rows_replace_all_index(self):
f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) "
f"ON CONFLICT ({', '.join(fields)}) DO NOTHING"
)
for row in rows:
self.cur.execute.assert_any_call(sql, row)
self.cur.executemany.assert_any_call(sql, rows)

def test_rowcount(self):
hook = PostgresHook()
22 changes: 12 additions & 10 deletions tests/providers/teradata/hooks/test_teradata.py
@@ -226,37 +226,39 @@ def test_insert_rows(self):
"str",
]
self.test_db_hook.insert_rows("table", rows, target_fields)
self.cur.execute.assert_called_once_with(
self.cur.executemany.assert_called_once_with(
"INSERT INTO table (basestring, none, datetime, int, float, str) VALUES (?,?,?,?,?,?)",
("'test_string", None, "2023-08-15T00:00:00", "1", "3.14", "str"),
[("'test_string", None, "2023-08-15T00:00:00", "1", "3.14", "str")],
)

def test_bulk_insert_rows_with_fields(self):
rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
target_fields = ["col1", "col2", "col3"]
self.test_db_hook.bulk_insert_rows("table", rows, target_fields)
self.cur.executemany.assert_called_once_with(
"INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", rows
"INSERT INTO table (col1, col2, col3) VALUES (?,?,?)",
[("1", "2", "3"), ("4", "5", "6"), ("7", "8", "9")],
)

def test_bulk_insert_rows_with_commit_every(self):
rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
target_fields = ["col1", "col2", "col3"]
self.test_db_hook.bulk_insert_rows("table", rows, target_fields, commit_every=2)
calls = [
mock.call("INSERT INTO table (col1, col2, col3) values (1, 2, 3)"),
mock.call("INSERT INTO table (col1, col2, col3) values (1, 2, 3)"),
]
calls = [
mock.call("INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", rows[:2]),
mock.call("INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", rows[2:]),
mock.call(
"INSERT INTO table (col1, col2, col3) VALUES (?,?,?)", [("1", "2", "3"), ("4", "5", "6")]
),
mock.call("INSERT INTO table (col1, col2, col3) VALUES (?,?,?)", [("7", "8", "9")]),
]
self.cur.executemany.assert_has_calls(calls, any_order=True)

def test_bulk_insert_rows_without_fields(self):
rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
self.test_db_hook.bulk_insert_rows("table", rows)
self.cur.executemany.assert_called_once_with("INSERT INTO table VALUES (?, ?, ?)", rows)
self.cur.executemany.assert_called_once_with(
"INSERT INTO table VALUES (?,?,?)",
[("1", "2", "3"), ("4", "5", "6"), ("7", "8", "9")],
)

def test_bulk_insert_rows_no_rows(self):
rows = []