From a3bfa25e6756222b6811e92c7a9e9f8de47ab630 Mon Sep 17 00:00:00 2001 From: Aakash Nand Date: Wed, 2 Nov 2022 14:13:09 +0900 Subject: [PATCH] Allow setting client tags for trino connection (#27213) --- airflow/providers/trino/hooks/trino.py | 4 ++-- docs/apache-airflow-providers-trino/connections.rst | 1 + tests/providers/trino/hooks/test_trino.py | 13 ++++++++++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index e4be4a092ab15..56cf5d0795b78 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -98,7 +98,6 @@ def get_conn(self) -> Connection: extra = db.extra_dejson auth = None user = db.login - session_properties = extra.get("session_properties") if db.password and extra.get("auth") in ("kerberos", "certs"): raise AirflowException(f"The {extra.get('auth')!r} authorization type doesn't support password.") elif db.password: @@ -143,7 +142,8 @@ def get_conn(self) -> Connection: # type: ignore[func-returns-value] isolation_level=self.get_isolation_level(), verify=_boolify(extra.get("verify", True)), - session_properties=session_properties if session_properties else None, + session_properties=extra.get("session_properties") or None, + client_tags=extra.get("client_tags") or None, ) return trino_conn diff --git a/docs/apache-airflow-providers-trino/connections.rst b/docs/apache-airflow-providers-trino/connections.rst index 3dc279332d1ef..3c50cd1b4550a 100644 --- a/docs/apache-airflow-providers-trino/connections.rst +++ b/docs/apache-airflow-providers-trino/connections.rst @@ -52,3 +52,4 @@ Extra (optional, connection parameters) * ``certs__client_cert_path``, ``certs__client_key_path``- If certificate authentication should be used, the path to the client certificate and key is given via these parameters. * ``kerberos__service_name``, ``kerberos__config``, ``kerberos__mutual_authentication``, ``kerberos__force_preemptive``, ``kerberos__hostname_override``, ``kerberos__sanitize_mutual_error_response``, ``kerberos__principal``,``kerberos__delegate``, ``kerberos__ca_bundle`` - These parameters can be set when enabling ``kerberos`` authentication. * ``session_properties`` - JSON dictionary which allows to set session_properties. Example: ``{'session_properties':{'scale_writers':true,'task_writer_count:1'}}`` + * ``client_tags`` - List of comma separated tags. Example ``{'client_tags':['sales','cluster1']}``` diff --git a/tests/providers/trino/hooks/test_trino.py b/tests/providers/trino/hooks/test_trino.py index a30087952089d..9f17ec69a37db 100644 --- a/tests/providers/trino/hooks/test_trino.py +++ b/tests/providers/trino/hooks/test_trino.py @@ -169,6 +169,16 @@ def test_get_conn_session_properties(self, mock_connect, mock_get_connection): self.assert_connection_called_with(mock_connect, session_properties=extras["session_properties"]) + @patch(HOOK_GET_CONNECTION) + @patch(TRINO_DBAPI_CONNECT) + def test_get_conn_client_tags(self, mock_connect, mock_get_connection): + extras = {"client_tags": ["abc", "xyz"]} + + self.set_get_connection_return_value(mock_get_connection, extra=extras) + TrinoHook().get_conn() + + self.assert_connection_called_with(mock_connect, client_tags=extras["client_tags"]) + @parameterized.expand( [ ("False", False), @@ -195,7 +205,7 @@ def set_get_connection_return_value(mock_get_connection, extra=None, password=No @staticmethod def assert_connection_called_with( - mock_connect, http_headers=mock.ANY, auth=None, verify=True, session_properties=None + mock_connect, http_headers=mock.ANY, auth=None, verify=True, session_properties=None, client_tags=None ): mock_connect.assert_called_once_with( catalog="hive", @@ -210,6 +220,7 @@ def assert_connection_called_with( auth=None if not auth else auth.return_value, verify=verify, session_properties=session_properties, + client_tags=client_tags, )