Skip to content

Commit

Permalink
Update Exasol to common DBApiHook semantics and add tests (#28009)
Browse files Browse the repository at this point in the history
The exasol hook now uses the same semantics as all other DbApi
Hook. Since (for now) it has a separate run() method, it also
has a comprehensive tests now covering all kinds of combinations
of parmeters.
  • Loading branch information
potiuk committed Nov 30, 2022
1 parent 8924cf1 commit 430e930
Show file tree
Hide file tree
Showing 3 changed files with 391 additions and 5 deletions.
18 changes: 13 additions & 5 deletions airflow/providers/exasol/hooks/exasol.py
Expand Up @@ -157,19 +157,21 @@ def run(
:param return_last: Whether to return result for only last statement or for all after split
:return: return only result of the LAST SQL expression if handler was provided.
"""
self.descriptions = []
if isinstance(sql, str):
if split_statements:
sql_list: Iterable[str] = self.split_sql_string(sql)
else:
sql_list = [self.strip_sql_string(sql)]
statement = self.strip_sql_string(sql)
sql_list = [statement] if statement.strip() else []
else:
sql_list = sql

if sql_list:
self.log.debug("Executing following statements against Exasol DB: %s", list(sql_list))
else:
raise ValueError("List of SQL statements is empty")

_last_result = None
with closing(self.get_conn()) as conn:
self.set_autocommit(conn, autocommit)
results = []
Expand All @@ -178,7 +180,12 @@ def run(
self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
if handler is not None:
result = handler(cur)
results.append(result)
if return_single_query_results(sql, return_last, split_statements):
_last_result = result
_last_description = cur.description
else:
results.append(result)
self.descriptions.append(cur.description)

self.log.info("Rows affected: %s", cur.rowcount)

Expand All @@ -188,8 +195,9 @@ def run(

if handler is None:
return None
elif return_single_query_results(sql, return_last, split_statements):
return results[-1]
if return_single_query_results(sql, return_last, split_statements):
self.descriptions = [_last_description]
return _last_result
else:
return results

Expand Down
228 changes: 228 additions & 0 deletions tests/providers/exasol/hooks/test_sql.py
@@ -0,0 +1,228 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

#
from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock, patch

import pytest

from airflow.models import Connection
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.exasol.hooks.exasol import ExasolHook
from airflow.utils.session import provide_session

TASK_ID = "sql-operator"
HOST = "host"
DEFAULT_CONN_ID = "exasol_default"
PASSWORD = "password"


class ExasolHookForTests(ExasolHook):
conn_name_attr = "exasol_conn_id"
get_conn = MagicMock(name="conn")


@provide_session
@pytest.fixture(autouse=True)
def create_connection(session):
conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
if conn is None:
conn = Connection(conn_id=DEFAULT_CONN_ID)
conn.host = HOST
conn.login = None
conn.password = PASSWORD
conn.extra = None
session.commit()


@pytest.fixture
def exasol_hook():
return ExasolHook()


def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
return [(field,) for field in fields]


index = 0


@pytest.mark.parametrize(
"return_last, split_statements, sql, cursor_calls,"
"cursor_descriptions, cursor_results, hook_descriptions, hook_results, ",
[
pytest.param(
True,
False,
"select * from test.test",
["select * from test.test"],
[["id", "value"]],
([[1, 2], [11, 12]],),
[[("id",), ("value",)]],
[[1, 2], [11, 12]],
id="The return_last set and no split statements set on single query in string",
),
pytest.param(
False,
False,
"select * from test.test;",
["select * from test.test;"],
[["id", "value"]],
([[1, 2], [11, 12]],),
[[("id",), ("value",)]],
[[1, 2], [11, 12]],
id="The return_last not set and no split statements set on single query in string",
),
pytest.param(
True,
True,
"select * from test.test;",
["select * from test.test;"],
[["id", "value"]],
([[1, 2], [11, 12]],),
[[("id",), ("value",)]],
[[1, 2], [11, 12]],
id="The return_last set and split statements set on single query in string",
),
pytest.param(
False,
True,
"select * from test.test;",
["select * from test.test;"],
[["id", "value"]],
([[1, 2], [11, 12]],),
[[("id",), ("value",)]],
[[[1, 2], [11, 12]]],
id="The return_last not set and split statements set on single query in string",
),
pytest.param(
True,
True,
"select * from test.test;select * from test.test2;",
["select * from test.test;", "select * from test.test2;"],
[["id", "value"], ["id2", "value2"]],
([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
[[("id2",), ("value2",)]],
[[3, 4], [13, 14]],
id="The return_last set and split statements set on multiple queries in string",
), # Failing
pytest.param(
False,
True,
"select * from test.test;select * from test.test2;",
["select * from test.test;", "select * from test.test2;"],
[["id", "value"], ["id2", "value2"]],
([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
[[("id",), ("value",)], [("id2",), ("value2",)]],
[[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
id="The return_last not set and split statements set on multiple queries in string",
),
pytest.param(
True,
True,
["select * from test.test;"],
["select * from test.test"],
[["id", "value"]],
([[1, 2], [11, 12]],),
[[("id",), ("value",)]],
[[[1, 2], [11, 12]]],
id="The return_last set on single query in list",
),
pytest.param(
False,
True,
["select * from test.test;"],
["select * from test.test"],
[["id", "value"]],
([[1, 2], [11, 12]],),
[[("id",), ("value",)]],
[[[1, 2], [11, 12]]],
id="The return_last not set on single query in list",
),
pytest.param(
True,
True,
"select * from test.test;select * from test.test2;",
["select * from test.test", "select * from test.test2"],
[["id", "value"], ["id2", "value2"]],
([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
[[("id2",), ("value2",)]],
[[3, 4], [13, 14]],
id="The return_last set set on multiple queries in list",
),
pytest.param(
False,
True,
"select * from test.test;select * from test.test2;",
["select * from test.test", "select * from test.test2"],
[["id", "value"], ["id2", "value2"]],
([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
[[("id",), ("value",)], [("id2",), ("value2",)]],
[[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
id="The return_last not set on multiple queries not set",
),
],
)
def test_query(
exasol_hook,
return_last,
split_statements,
sql,
cursor_calls,
cursor_descriptions,
cursor_results,
hook_descriptions,
hook_results,
):
with patch("airflow.providers.exasol.hooks.exasol.ExasolHook.get_conn") as mock_conn:
cursors = []
for index in range(len(cursor_descriptions)):
cur = mock.MagicMock(
rowcount=len(cursor_results[index]),
description=get_cursor_descriptions(cursor_descriptions[index]),
)
cur.fetchall.return_value = cursor_results[index]
cursors.append(cur)
mock_conn.execute.side_effect = cursors
mock_conn.return_value = mock_conn
results = exasol_hook.run(
sql=sql, handler=fetch_all_handler, return_last=return_last, split_statements=split_statements
)

assert exasol_hook.descriptions == hook_descriptions
assert exasol_hook.last_description == hook_descriptions[-1]
assert results == hook_results
cur.close.assert_called()


@pytest.mark.parametrize(
"empty_statement",
[
pytest.param([], id="Empty list"),
pytest.param("", id="Empty string"),
pytest.param("\n", id="Only EOL"),
],
)
def test_no_query(empty_statement):
dbapi_hook = ExasolHookForTests()
dbapi_hook.get_conn.return_value.cursor.rowcount = 0
with pytest.raises(ValueError) as err:
dbapi_hook.run(sql=empty_statement)
assert err.value.args[0] == "List of SQL statements is empty"

0 comments on commit 430e930

Please sign in to comment.