Skip to content

Commit

Permalink
Allow setting client tags for trino connection (#27213)
Browse files Browse the repository at this point in the history
  • Loading branch information
aakashnand committed Nov 2, 2022
1 parent 1a3f785 commit a3bfa25
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/trino/hooks/trino.py
Expand Up @@ -98,7 +98,6 @@ def get_conn(self) -> Connection:
extra = db.extra_dejson
auth = None
user = db.login
session_properties = extra.get("session_properties")
if db.password and extra.get("auth") in ("kerberos", "certs"):
raise AirflowException(f"The {extra.get('auth')!r} authorization type doesn't support password.")
elif db.password:
Expand Down Expand Up @@ -143,7 +142,8 @@ def get_conn(self) -> Connection:
# type: ignore[func-returns-value]
isolation_level=self.get_isolation_level(),
verify=_boolify(extra.get("verify", True)),
session_properties=session_properties if session_properties else None,
session_properties=extra.get("session_properties") or None,
client_tags=extra.get("client_tags") or None,
)

return trino_conn
Expand Down
1 change: 1 addition & 0 deletions docs/apache-airflow-providers-trino/connections.rst
Expand Up @@ -52,3 +52,4 @@ Extra (optional, connection parameters)
* ``certs__client_cert_path``, ``certs__client_key_path``- If certificate authentication should be used, the path to the client certificate and key is given via these parameters.
* ``kerberos__service_name``, ``kerberos__config``, ``kerberos__mutual_authentication``, ``kerberos__force_preemptive``, ``kerberos__hostname_override``, ``kerberos__sanitize_mutual_error_response``, ``kerberos__principal``,``kerberos__delegate``, ``kerberos__ca_bundle`` - These parameters can be set when enabling ``kerberos`` authentication.
* ``session_properties`` - JSON dictionary which allows to set session_properties. Example: ``{'session_properties':{'scale_writers':true,'task_writer_count:1'}}``
* ``client_tags`` - List of comma separated tags. Example ``{'client_tags':['sales','cluster1']}```
13 changes: 12 additions & 1 deletion tests/providers/trino/hooks/test_trino.py
Expand Up @@ -169,6 +169,16 @@ def test_get_conn_session_properties(self, mock_connect, mock_get_connection):

self.assert_connection_called_with(mock_connect, session_properties=extras["session_properties"])

@patch(HOOK_GET_CONNECTION)
@patch(TRINO_DBAPI_CONNECT)
def test_get_conn_client_tags(self, mock_connect, mock_get_connection):
extras = {"client_tags": ["abc", "xyz"]}

self.set_get_connection_return_value(mock_get_connection, extra=extras)
TrinoHook().get_conn()

self.assert_connection_called_with(mock_connect, client_tags=extras["client_tags"])

@parameterized.expand(
[
("False", False),
Expand All @@ -195,7 +205,7 @@ def set_get_connection_return_value(mock_get_connection, extra=None, password=No

@staticmethod
def assert_connection_called_with(
mock_connect, http_headers=mock.ANY, auth=None, verify=True, session_properties=None
mock_connect, http_headers=mock.ANY, auth=None, verify=True, session_properties=None, client_tags=None
):
mock_connect.assert_called_once_with(
catalog="hive",
Expand All @@ -210,6 +220,7 @@ def assert_connection_called_with(
auth=None if not auth else auth.return_value,
verify=verify,
session_properties=session_properties,
client_tags=client_tags,
)


Expand Down

0 comments on commit a3bfa25

Please sign in to comment.