Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional result handler to database hooks #15581

Merged
merged 6 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ New Features
Improvements
""""""""""""

- Add optional result handler callback to ``DbApiHook`` (#15581)
- Update Flask App Builder limit to recently released 3.3 (#15792)
- Prevent creating flask sessions on REST API requests (#15295)
- Sync DAG specific permissions when parsing (#15311)
Expand Down
41 changes: 31 additions & 10 deletions airflow/hooks/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def get_first(self, sql, parameters=None):
cur.execute(sql)
return cur.fetchone()

def run(self, sql, autocommit=False, parameters=None):
def run(self, sql, autocommit=False, parameters=None, handler=None):
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
Expand All @@ -166,30 +166,51 @@ def run(self, sql, autocommit=False, parameters=None):
:type autocommit: bool
:param parameters: The parameters to render the SQL query with.
:type parameters: dict or iterable
:param handler: The result handler which is called with the result of each statement.
:type handler: callable
:return: query results if handler was provided.
"""
if isinstance(sql, str):
scalar = isinstance(sql, str)
if scalar:
sql = [sql]

with closing(self.get_conn()) as conn:
if self.supports_autocommit:
self.set_autocommit(conn, autocommit)

with closing(conn.cursor()) as cur:
results = []
for sql_statement in sql:

self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
if parameters:
cur.execute(sql_statement, parameters)
else:
cur.execute(sql_statement)
if hasattr(cur, 'rowcount'):
self.log.info("Rows affected: %s", cur.rowcount)
self._run_command(cur, sql_statement, parameters)
if handler is not None:
result = handler(cur)
results.append(result)

# If autocommit was set to False for db that supports autocommit,
# or if db does not supports autocommit, we do a manual commit.
if not self.get_autocommit(conn):
conn.commit()

if handler is None:
return None

if scalar:
return results[0]

return results

def _run_command(self, cur, sql_statement, parameters):
"""Runs a statement using an already open cursor."""
self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
if parameters:
cur.execute(sql_statement, parameters)
else:
cur.execute(sql_statement)

# According to PEP 249, this is -1 when query result is not applicable.
if cur.rowcount >= 0:
self.log.info("Rows affected: %s", cur.rowcount)

def set_autocommit(self, conn, autocommit):
"""Sets the autocommit flag on the connection"""
if not self.supports_autocommit and autocommit:
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/oracle/operators/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class OracleOperator(BaseOperator):
"""
Executes sql code in a specific Oracle database
Executes sql code in a specific Oracle database.

:param sql: the sql code to be executed. Can receive a str representing a sql statement,
a list of str (sql statements), or reference to a template file.
Expand Down
36 changes: 35 additions & 1 deletion tests/hooks/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class TestDbApiHook(unittest.TestCase):
def setUp(self):
super().setUp()

self.cur = mock.MagicMock()
self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
Expand Down Expand Up @@ -184,3 +184,37 @@ def test_run_log(self):
statement = 'SQL'
self.db_hook.run(statement)
assert self.db_hook.log.info.call_count == 2

def test_run_with_handler(self):
sql = 'SQL'
param = ('p1', 'p2')
called = 0
obj = object()

def handler(cur):
cur.execute.assert_called_once_with(sql, param)
nonlocal called
called += 1
return obj

result = self.db_hook.run(sql, parameters=param, handler=handler)
assert called == 1
assert self.conn.commit.called
assert result == obj

def test_run_with_handler_multiple(self):
sql = ['SQL', 'SQL']
param = ('p1', 'p2')
called = 0
obj = object()

def handler(cur):
cur.execute.assert_called_with(sql[0], param)
nonlocal called
called += 1
return obj

result = self.db_hook.run(sql, parameters=param, handler=handler)
assert called == 2
assert self.conn.commit.called
assert result == [obj, obj]
2 changes: 1 addition & 1 deletion tests/providers/apache/druid/hooks/test_druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def test_get_auth_with_no_user_and_password(self, mock_get_connection):
class TestDruidDbApiHook(unittest.TestCase):
def setUp(self):
super().setUp()
self.cur = MagicMock()
self.cur = MagicMock(rowcount=0)
self.conn = conn = MagicMock()
self.conn.host = 'host'
self.conn.port = '1000'
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/apache/pinot/hooks/test_pinot.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def setUp(self):
self.conn.port = '1000'
self.conn.conn_type = 'http'
self.conn.extra_dejson = {'endpoint': 'query/sql'}
self.cur = mock.MagicMock()
self.cur = mock.MagicMock(rowcount=0)
self.conn.cursor.return_value = self.cur
self.conn.__enter__.return_value = self.cur
self.conn.__exit__.return_value = None
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/elasticsearch/hooks/test_elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class TestElasticsearchHook(unittest.TestCase):
def setUp(self):
super().setUp()

self.cur = mock.MagicMock()
self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/exasol/hooks/test_exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class TestExasolHook(unittest.TestCase):
def setUp(self):
super().setUp()

self.cur = mock.MagicMock()
self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.execute.return_value = self.cur
conn = self.conn
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/mysql/hooks/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class TestMySqlHook(unittest.TestCase):
def setUp(self):
super().setUp()

self.cur = mock.MagicMock()
self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/oracle/hooks/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class TestOracleHook(unittest.TestCase):
def setUp(self):
super().setUp()

self.cur = mock.MagicMock()
self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
Expand Down
10 changes: 7 additions & 3 deletions tests/providers/oracle/operators/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


class TestOracleOperator(unittest.TestCase):
@mock.patch.object(OracleHook, 'run')
@mock.patch.object(OracleHook, 'run', autospec=OracleHook.run)
def test_execute(self, mock_run):
sql = 'SELECT * FROM test_table'
oracle_conn_id = 'oracle_default'
Expand All @@ -40,5 +40,9 @@ def test_execute(self, mock_run):
task_id=task_id,
)
operator.execute(context=context)

mock_run.assert_called_once_with(sql, autocommit=autocommit, parameters=parameters)
mock_run.assert_called_once_with(
mock.ANY,
sql,
autocommit=autocommit,
parameters=parameters,
)
2 changes: 1 addition & 1 deletion tests/providers/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __init__(self, *args, **kwargs):
def setUp(self):
super().setUp()

self.cur = mock.MagicMock()
self.cur = mock.MagicMock(rowcount=0)
self.conn = conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur

Expand Down
2 changes: 1 addition & 1 deletion tests/providers/presto/hooks/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class TestPrestoHook(unittest.TestCase):
def setUp(self):
super().setUp()

self.cur = mock.MagicMock()
self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class TestSnowflakeHook(unittest.TestCase):
def setUp(self):
super().setUp()

self.cur = mock.MagicMock()
self.cur2 = mock.MagicMock()
self.cur = mock.MagicMock(rowcount=0)
self.cur2 = mock.MagicMock(rowcount=0)

self.cur.sfqid = 'uuid'
self.cur2.sfqid = 'uuid2'
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/sqlite/hooks/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_get_conn_non_default_id(self, mock_connect):
class TestSqliteHook(unittest.TestCase):
def setUp(self):

self.cur = mock.MagicMock()
self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/trino/hooks/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class TestTrinoHook(unittest.TestCase):
def setUp(self):
super().setUp()

self.cur = mock.MagicMock()
self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/vertica/hooks/test_vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class TestVerticaHook(unittest.TestCase):
def setUp(self):
super().setUp()

self.cur = mock.MagicMock()
self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
Expand Down