Skip to content

Commit

Permalink
Always use the executemany method when inserting rows in DbApiHook as…
Browse files Browse the repository at this point in the history
… it's way much faster (#38715)


---------

Co-authored-by: David Blain <david.blain@infrabel.be>
Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
  • Loading branch information
3 people authored Apr 12, 2024
1 parent 53dcbce commit 7ab24c7
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 70 deletions.
52 changes: 30 additions & 22 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import contextlib
import warnings
from contextlib import closing
from contextlib import closing, contextmanager
from datetime import datetime
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -147,6 +147,8 @@ class DbApiHook(BaseHook):
default_conn_name = "default_conn_id"
# Override if this db supports autocommit.
supports_autocommit = False
# Override if this db supports executemany.
supports_executemany = False
# Override with the object that exposes the connect method
connector: ConnectorProtocol | None = None
# Override with db-specific query to check connection
Expand Down Expand Up @@ -408,10 +410,7 @@ def run(
else:
raise ValueError("List of SQL statements is empty")
_last_result = None
with closing(self.get_conn()) as conn:
if self.supports_autocommit:
self.set_autocommit(conn, autocommit)

with self._create_autocommit_connection(autocommit) as conn:
with closing(conn.cursor()) as cur:
results = []
for sql_statement in sql_list:
Expand Down Expand Up @@ -528,6 +527,14 @@ def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs)

return self._replace_statement_format.format(table, target_fields, ",".join(placeholders))

@contextmanager
def _create_autocommit_connection(self, autocommit: bool = False):
"""Context manager that closes the connection after use and detects if autocommit is supported."""
with closing(self.get_conn()) as conn:
if self.supports_autocommit:
self.set_autocommit(conn, autocommit)
yield conn

def insert_rows(
self,
table,
Expand All @@ -550,47 +557,48 @@ def insert_rows(
:param commit_every: The maximum number of rows to insert in one
transaction. Set to 0 to insert all rows in one transaction.
:param replace: Whether to replace instead of insert
:param executemany: Insert all rows at once in chunks defined by the commit_every parameter, only
works if all rows have same number of column names but leads to better performance
:param executemany: (Deprecated) If True, all rows are inserted at once in
chunks defined by the commit_every parameter. This only works if all rows
have same number of column names, but leads to better performance.
"""
i = 0
with closing(self.get_conn()) as conn:
if self.supports_autocommit:
self.set_autocommit(conn, False)
if executemany:
warnings.warn(
"executemany parameter is deprecated, override supports_executemany instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

with self._create_autocommit_connection() as conn:
conn.commit()

with closing(conn.cursor()) as cur:
if executemany:
if self.supports_executemany or executemany:
for chunked_rows in chunked(rows, commit_every):
values = list(
map(
lambda row: tuple(map(lambda cell: self._serialize_cell(cell, conn), row)),
lambda row: self._serialize_cells(row, conn),
chunked_rows,
)
)
sql = self._generate_insert_sql(table, values[0], target_fields, replace, **kwargs)
self.log.debug("Generated sql: %s", sql)
cur.fast_executemany = True
cur.executemany(sql, values)
conn.commit()
self.log.info("Loaded %s rows into %s so far", len(chunked_rows), table)
else:
for i, row in enumerate(rows, 1):
lst = []
for cell in row:
lst.append(self._serialize_cell(cell, conn))
values = tuple(lst)
values = self._serialize_cells(row, conn)
sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs)
self.log.debug("Generated sql: %s", sql)
cur.execute(sql, values)
if commit_every and i % commit_every == 0:
conn.commit()
self.log.info("Loaded %s rows into %s so far", i, table)
conn.commit()
self.log.info("Done loading. Loaded a total of %s rows into %s", len(rows), table)

if not executemany:
conn.commit()
self.log.info("Done loading. Loaded a total of %s rows into %s", i, table)
@classmethod
def _serialize_cells(cls, row, conn=None):
return tuple(cls._serialize_cell(cell, conn) for cell in row)

@staticmethod
def _serialize_cell(cell, conn=None) -> str | None:
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/common/sql/hooks/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class DbApiHook(BaseHook):
conn_name_attr: str
default_conn_name: str
supports_autocommit: bool
supports_executemany: bool
connector: ConnectorProtocol | None
log_sql: Incomplete
descriptions: Incomplete
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/odbc/hooks/odbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class OdbcHook(DbApiHook):
conn_type = "odbc"
hook_name = "ODBC"
supports_autocommit = True
supports_executemany = True

default_driver: str | None = None

Expand Down
1 change: 1 addition & 0 deletions airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class PostgresHook(DbApiHook):
conn_type = "postgres"
hook_name = "Postgres"
supports_autocommit = True
supports_executemany = True

def __init__(self, *args, options: str | None = None, **kwargs) -> None:
if "schema" in kwargs:
Expand Down
48 changes: 17 additions & 31 deletions airflow/providers/teradata/hooks/teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any

import sqlalchemy
import teradatasql
from teradatasql import TeradataConnection

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
Expand Down Expand Up @@ -59,6 +61,9 @@ class TeradataHook(DbApiHook):
# Override if this db supports autocommit.
supports_autocommit = True

# Override if this db supports executemany.
supports_executemany = True

# Override this for hook to have a custom name in the UI selection
conn_type = "teradata"

Expand Down Expand Up @@ -97,7 +102,9 @@ def bulk_insert_rows(
target_fields: list[str] | None = None,
commit_every: int = 5000,
):
"""Insert bulk of records into Teradata SQL Database.
"""Use :func:`insert_rows` instead, this is deprecated.
Insert bulk of records into Teradata SQL Database.
This uses prepared statements via `executemany()`. For best performance,
pass in `rows` as an iterator.
Expand All @@ -106,41 +113,20 @@ def bulk_insert_rows(
specific database
:param rows: the rows to insert into the table
:param target_fields: the names of the columns to fill in the table, default None.
If None, each rows should have some order as table columns name
If None, each row should have some order as table columns name
:param commit_every: the maximum number of rows to insert in one transaction
Default 5000. Set greater than 0. Set 1 to insert each row in each transaction
"""
warnings.warn(
"bulk_insert_rows is deprecated. Please use the insert_rows method instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

if not rows:
raise ValueError("parameter rows could not be None or empty iterable")
conn = self.get_conn()
if self.supports_autocommit:
self.set_autocommit(conn, False)
cursor = conn.cursor()
cursor.fast_executemany = True
values_base = target_fields if target_fields else rows[0]
prepared_stm = "INSERT INTO {tablename} {columns} VALUES ({values})".format(
tablename=table,
columns="({})".format(", ".join(target_fields)) if target_fields else "",
values=", ".join("?" for i in range(1, len(values_base) + 1)),
)
row_count = 0
# Chunk the rows
row_chunk = []
for row in rows:
row_chunk.append(row)
row_count += 1
if row_count % commit_every == 0:
cursor.executemany(prepared_stm, row_chunk)
conn.commit() # type: ignore[attr-defined]
# Empty chunk
row_chunk = []
# Commit the leftover chunk
if len(row_chunk) > 0:
cursor.executemany(prepared_stm, row_chunk)
conn.commit() # type: ignore[attr-defined]
self.log.info("[%s] inserted %s rows", table, row_count)
cursor.close()
conn.close() # type: ignore[attr-defined]

self.insert_rows(table=table, rows=rows, target_fields=target_fields, commit_every=commit_every)

def _get_conn_config_teradatasql(self) -> dict[str, Any]:
"""Return set of config params required for connecting to Teradata DB using teradatasql client."""
Expand Down
6 changes: 6 additions & 0 deletions tests/deprecations_ignore.yml
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,8 @@
- tests/providers/cncf/kubernetes/test_pod_generator.py::TestPodGenerator::test_pod_name_confirm_to_max_length
- tests/providers/cncf/kubernetes/test_pod_generator.py::TestPodGenerator::test_pod_name_is_valid
- tests/providers/cncf/kubernetes/test_template_rendering.py::test_render_k8s_pod_yaml
- tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_insert_rows_executemany
- tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_insert_rows_replace_executemany_hana_dialect
- tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_instance_check_works_for_legacy_db_api_hook
- tests/providers/common/sql/operators/test_sql.py::TestSQLCheckOperatorDbHook::test_get_hook
- tests/providers/common/sql/operators/test_sql.py::TestSqlBranch::test_branch_false_with_dag_run
Expand Down Expand Up @@ -1059,6 +1061,10 @@
- tests/providers/ssh/hooks/test_ssh.py::TestSSHHook::test_tunnel_without_password
- tests/providers/tableau/hooks/test_tableau.py::TestTableauHook::test_get_conn_auth_via_token_and_site_in_init
- tests/providers/tableau/hooks/test_tableau.py::TestTableauHook::test_get_conn_ssl_default
- tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_with_fields
- tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_with_commit_every
- tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_without_fields
- tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_no_rows
- tests/providers/trino/operators/test_trino.py::test_execute_openlineage_events
- tests/providers/vertica/operators/test_vertica.py::TestVerticaOperator::test_execute
- tests/providers/weaviate/operators/test_weaviate.py::TestWeaviateIngestOperator::test_constructor
Expand Down
3 changes: 2 additions & 1 deletion tests/providers/common/sql/hooks/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from unittest import mock

import pytest
from pyodbc import Cursor

from airflow.hooks.base import BaseHook
from airflow.models import Connection
Expand All @@ -39,7 +40,7 @@ class TestDbApiHook:
def setup_method(self, **kwargs):
self.cur = mock.MagicMock(
rowcount=0,
spec=["description", "rowcount", "execute", "executemany", "fetchall", "fetchone", "close"],
spec=Cursor,
)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
Expand Down
9 changes: 3 additions & 6 deletions tests/providers/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,7 @@ def test_insert_rows(self):
assert commit_count == self.conn.commit.call_count

sql = f"INSERT INTO {table} VALUES (%s)"
for row in rows:
self.cur.execute.assert_any_call(sql, row)
self.cur.executemany.assert_any_call(sql, rows)

def test_insert_rows_replace(self):
table = "table"
Expand Down Expand Up @@ -432,8 +431,7 @@ def test_insert_rows_replace(self):
f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) "
f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = excluded.{fields[1]}"
)
for row in rows:
self.cur.execute.assert_any_call(sql, row)
self.cur.executemany.assert_any_call(sql, rows)

def test_insert_rows_replace_missing_target_field_arg(self):
table = "table"
Expand Down Expand Up @@ -497,8 +495,7 @@ def test_insert_rows_replace_all_index(self):
f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) "
f"ON CONFLICT ({', '.join(fields)}) DO NOTHING"
)
for row in rows:
self.cur.execute.assert_any_call(sql, row)
self.cur.executemany.assert_any_call(sql, rows)

def test_rowcount(self):
hook = PostgresHook()
Expand Down
22 changes: 12 additions & 10 deletions tests/providers/teradata/hooks/test_teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,37 +226,39 @@ def test_insert_rows(self):
"str",
]
self.test_db_hook.insert_rows("table", rows, target_fields)
self.cur.execute.assert_called_once_with(
self.cur.executemany.assert_called_once_with(
"INSERT INTO table (basestring, none, datetime, int, float, str) VALUES (?,?,?,?,?,?)",
("'test_string", None, "2023-08-15T00:00:00", "1", "3.14", "str"),
[("'test_string", None, "2023-08-15T00:00:00", "1", "3.14", "str")],
)

def test_bulk_insert_rows_with_fields(self):
rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
target_fields = ["col1", "col2", "col3"]
self.test_db_hook.bulk_insert_rows("table", rows, target_fields)
self.cur.executemany.assert_called_once_with(
"INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", rows
"INSERT INTO table (col1, col2, col3) VALUES (?,?,?)",
[("1", "2", "3"), ("4", "5", "6"), ("7", "8", "9")],
)

def test_bulk_insert_rows_with_commit_every(self):
rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
target_fields = ["col1", "col2", "col3"]
self.test_db_hook.bulk_insert_rows("table", rows, target_fields, commit_every=2)
calls = [
mock.call("INSERT INTO table (col1, col2, col3) values (1, 2, 3)"),
mock.call("INSERT INTO table (col1, col2, col3) values (1, 2, 3)"),
]
calls = [
mock.call("INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", rows[:2]),
mock.call("INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", rows[2:]),
mock.call(
"INSERT INTO table (col1, col2, col3) VALUES (?,?,?)", [("1", "2", "3"), ("4", "5", "6")]
),
mock.call("INSERT INTO table (col1, col2, col3) VALUES (?,?,?)", [("7", "8", "9")]),
]
self.cur.executemany.assert_has_calls(calls, any_order=True)

def test_bulk_insert_rows_without_fields(self):
rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
self.test_db_hook.bulk_insert_rows("table", rows)
self.cur.executemany.assert_called_once_with("INSERT INTO table VALUES (?, ?, ?)", rows)
self.cur.executemany.assert_called_once_with(
"INSERT INTO table VALUES (?,?,?)",
[("1", "2", "3"), ("4", "5", "6"), ("7", "8", "9")],
)

def test_bulk_insert_rows_no_rows(self):
rows = []
Expand Down

0 comments on commit 7ab24c7

Please sign in to comment.