Skip to content

Commit

Permalink
Pass Trino hook params to DbApiHook (#21479)
Browse files Browse the repository at this point in the history
  • Loading branch information
kazanzhy committed Feb 15, 2022
1 parent dc03000 commit 1884f22
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
20 changes: 14 additions & 6 deletions airflow/providers/trino/hooks/trino.py
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import os
from typing import Any, Iterable, Optional
from typing import Any, Callable, Iterable, Optional

import trino
from trino.exceptions import DatabaseError
Expand Down Expand Up @@ -106,7 +106,7 @@ def get_isolation_level(self) -> Any:
def _strip_sql(sql: str) -> str:
return sql.strip().rstrip(';')

def get_records(self, hql, parameters: Optional[dict] = None):
def get_records(self, hql: str, parameters: Optional[dict] = None):
"""Get a set of records from Trino"""
try:
return super().get_records(self._strip_sql(hql), parameters)
Expand All @@ -120,7 +120,7 @@ def get_first(self, hql: str, parameters: Optional[dict] = None) -> Any:
except DatabaseError as e:
raise TrinoException(e)

def get_pandas_df(self, hql, parameters=None, **kwargs):
def get_pandas_df(self, hql: str, parameters: Optional[dict] = None, **kwargs): # type: ignore[override]
"""Get a pandas dataframe from a sql query."""
import pandas

Expand All @@ -138,9 +138,17 @@ def get_pandas_df(self, hql, parameters=None, **kwargs):
df = pandas.DataFrame(**kwargs)
return df

def run(self, hql, autocommit: bool = False, parameters: Optional[dict] = None, handler=None) -> None:
def run(
self,
hql: str,
autocommit: bool = False,
parameters: Optional[dict] = None,
handler: Optional[Callable] = None,
) -> None:
"""Execute the statement against Trino. Can be used to create views."""
return super().run(sql=self._strip_sql(hql), parameters=parameters)
return super().run(
sql=self._strip_sql(hql), autocommit=autocommit, parameters=parameters, handler=handler
)

def insert_rows(
self,
Expand Down Expand Up @@ -169,4 +177,4 @@ def insert_rows(
)
commit_every = 0

super().insert_rows(table, rows, target_fields, commit_every)
super().insert_rows(table, rows, target_fields, commit_every, replace)
14 changes: 12 additions & 2 deletions tests/providers/trino/hooks/test_trino.py
Expand Up @@ -149,8 +149,9 @@ def test_insert_rows(self, mock_insert_rows):
rows = [("hello",), ("world",)]
target_fields = None
commit_every = 10
self.db_hook.insert_rows(table, rows, target_fields, commit_every)
mock_insert_rows.assert_called_once_with(table, rows, None, 10)
replace = True
self.db_hook.insert_rows(table, rows, target_fields, commit_every, replace)
mock_insert_rows.assert_called_once_with(table, rows, None, 10, True)

def test_get_first_record(self):
statement = 'SQL'
Expand Down Expand Up @@ -187,6 +188,15 @@ def test_get_pandas_df(self):

self.cur.execute.assert_called_once_with(statement, None)

@patch('airflow.hooks.dbapi.DbApiHook.run')
def test_run(self, mock_run):
hql = "SELECT 1"
autocommit = False
parameters = {"hello": "world"}
handler = str
self.db_hook.run(hql, autocommit, parameters, handler)
mock_run.assert_called_once_with(sql=hql, autocommit=False, parameters=parameters, handler=str)


class TestTrinoHookIntegration(unittest.TestCase):
@pytest.mark.integration("trino")
Expand Down

0 comments on commit 1884f22

Please sign in to comment.