Skip to content

Commit

Permalink
Properly handle verify parameter in TrinoHook (#18791)
Browse files Browse the repository at this point in the history
  • Loading branch information
danarwix committed Oct 7, 2021
1 parent cfa8fe2 commit 6bc0f87
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
13 changes: 4 additions & 9 deletions airflow/providers/trino/hooks/trino.py
Expand Up @@ -81,23 +81,18 @@ def get_conn(self) -> Connection:
delegate=_boolify(extra.get('kerberos__delegate', False)),
ca_bundle=extra.get('kerberos__ca_bundle'),
)

trino_conn = trino.dbapi.connect(
host=db.host,
port=db.port,
user=db.login,
source=db.extra_dejson.get('source', 'airflow'),
http_scheme=db.extra_dejson.get('protocol', 'http'),
catalog=db.extra_dejson.get('catalog', 'hive'),
source=extra.get('source', 'airflow'),
http_scheme=extra.get('protocol', 'http'),
catalog=extra.get('catalog', 'hive'),
schema=db.schema,
auth=auth,
isolation_level=self.get_isolation_level(), # type: ignore[func-returns-value]
verify=_boolify(extra.get('verify', True)),
)
if extra.get('verify') is not None:
# Unfortunately verify parameter is available via public API.
# The PR is merged in the trino library, but has not been released.
# See: https://github.com/trinodb/trino-python-client/pull/31
trino_conn._http_session.verify = _boolify(extra['verify'])

return trino_conn

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -476,7 +476,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
telegram = [
'python-telegram-bot~=13.0',
]
trino = ['trino']
trino = ['trino>=0.301.0']
vertica = [
'vertica-python>=0.5.1',
]
Expand Down
18 changes: 15 additions & 3 deletions tests/providers/trino/hooks/test_trino.py
Expand Up @@ -51,6 +51,7 @@ def test_get_conn_basic_auth(self, mock_get_connection, mock_connect, mock_basic
user='login',
isolation_level=0,
auth=mock_basic_auth.return_value,
verify=True,
)
mock_basic_auth.assert_called_once_with('login', 'password')
assert mock_connect.return_value == conn
Expand Down Expand Up @@ -89,6 +90,7 @@ def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_au
'kerberos__principal': 'TEST_PRINCIPAL',
'kerberos__delegate': 'TEST_DELEGATE',
'kerberos__ca_bundle': 'TEST_CA_BUNDLE',
'verify': 'true',
}
),
)
Expand All @@ -104,6 +106,7 @@ def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_au
user='login',
isolation_level=0,
auth=mock_auth.return_value,
verify=True,
)
mock_auth.assert_called_once_with(
ca_bundle='TEST_CA_BUNDLE',
Expand Down Expand Up @@ -135,11 +138,20 @@ def test_get_conn_verify(self, current_verify, expected_verify):
mock_get_connection.return_value = Connection(
login='login', host='host', schema='hive', extra=json.dumps({'verify': current_verify})
)
mock_verify = mock.PropertyMock()
type(mock_connect.return_value._http_session).verify = mock_verify

conn = TrinoHook().get_conn()
mock_verify.assert_called_once_with(expected_verify)
mock_connect.assert_called_once_with(
catalog='hive',
host='host',
port=None,
http_scheme='http',
schema='hive',
source='airflow',
user='login',
auth=None,
isolation_level=0,
verify=expected_verify,
)
assert mock_connect.return_value == conn


Expand Down

0 comments on commit 6bc0f87

Please sign in to comment.