Skip to content

Commit

Permalink
Fix fetch_all_handler & db-api tests for it (#25430)
Browse files Browse the repository at this point in the history
  • Loading branch information
FanatoniQ committed Aug 5, 2022
1 parent d3028ad commit d82436b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
9 changes: 3 additions & 6 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import warnings
from contextlib import closing
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Tuple, Type, Union
from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple, Type, Union

import sqlparse
from packaging.version import Version
Expand All @@ -30,13 +30,10 @@
from airflow.utils.module_loading import import_string
from airflow.version import version

if TYPE_CHECKING:
from sqlalchemy.engine import CursorResult


def fetch_all_handler(cursor: 'CursorResult') -> Optional[List[Tuple]]:
def fetch_all_handler(cursor) -> Optional[List[Tuple]]:
"""Handler for DbApiHook.run() to return results"""
if cursor.returns_rows:
if cursor.description is not None:
return cursor.fetchall()
else:
return None
Expand Down
24 changes: 22 additions & 2 deletions tests/providers/common/sql/hooks/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler


class DbApiHookInProvider(DbApiHook):
Expand All @@ -40,7 +40,9 @@ class TestDbApiHook(unittest.TestCase):
def setUp(self):
super().setUp()

self.cur = mock.MagicMock(rowcount=0)
self.cur = mock.MagicMock(
rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "close"]
)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
Expand Down Expand Up @@ -410,3 +412,21 @@ def test_instance_check_works_for_legacy_db_api_hook(self):
from airflow.hooks.dbapi import DbApiHook as LegacyDbApiHook

assert isinstance(DbApiHookInProvider(), LegacyDbApiHook)

def test_run_fetch_all_handler_select_1(self):
self.cur.rowcount = -1 # can be -1 according to pep249
self.cur.description = (tuple([None] * 7),)
query = "SELECT 1"
rows = [[1]]

self.cur.fetchall.return_value = rows
assert rows == self.db_hook.run(sql=query, handler=fetch_all_handler)

def test_run_fetch_all_handler_print(self):
self.cur.rowcount = -1
self.cur.description = None
query = "PRINT('Hello World !')"
rows = None

self.cur.fetchall.side_effect = Exception("Should not get called !")
assert rows == self.db_hook.run(sql=query, handler=fetch_all_handler)

0 comments on commit d82436b

Please sign in to comment.