diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index 070adabd01332..09fbe6efa66ba 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -95,7 +95,12 @@ def get_conn(self) -> Connection: elif db.password: auth = trino.auth.BasicAuthentication(db.login, db.password) # type: ignore[attr-defined] elif extra.get("auth") == "jwt": - auth = trino.auth.JWTAuthentication(token=extra.get("jwt__token")) + if "jwt__file" in extra: + with open(extra.get("jwt__file")) as jwt_file: + token = jwt_file.read() + else: + token = extra.get("jwt__token") + auth = trino.auth.JWTAuthentication(token=token) elif extra.get("auth") == "certs": auth = trino.auth.CertificateAuthentication( extra.get("certs__client_cert_path"), diff --git a/docs/apache-airflow-providers-trino/connections.rst b/docs/apache-airflow-providers-trino/connections.rst index 3c50cd1b4550a..ee25fbbce7183 100644 --- a/docs/apache-airflow-providers-trino/connections.rst +++ b/docs/apache-airflow-providers-trino/connections.rst @@ -49,7 +49,10 @@ Extra (optional, connection parameters) The following extra parameters can be used to configure authentication: * ``jwt__token`` - If jwt authentication should be used, the value of token is given via this parameter. + * ``jwt__file`` - If jwt authentication should be used, the location on disk for the file containing the jwt token. * ``certs__client_cert_path``, ``certs__client_key_path``- If certificate authentication should be used, the path to the client certificate and key is given via these parameters. * ``kerberos__service_name``, ``kerberos__config``, ``kerberos__mutual_authentication``, ``kerberos__force_preemptive``, ``kerberos__hostname_override``, ``kerberos__sanitize_mutual_error_response``, ``kerberos__principal``,``kerberos__delegate``, ``kerberos__ca_bundle`` - These parameters can be set when enabling ``kerberos`` authentication. * ``session_properties`` - JSON dictionary which allows to set session_properties. Example: ``{'session_properties':{'scale_writers':true,'task_writer_count:1'}}`` * ``client_tags`` - List of comma separated tags. Example ``{'client_tags':['sales','cluster1']}``` + + Note: If ``jwt__file`` and ``jwt__token`` are both given, ``jwt__file`` will take precedent. diff --git a/tests/providers/trino/hooks/test_trino.py b/tests/providers/trino/hooks/test_trino.py index 4a0f2e6d2f680..5a9e51bf096fe 100644 --- a/tests/providers/trino/hooks/test_trino.py +++ b/tests/providers/trino/hooks/test_trino.py @@ -18,7 +18,9 @@ from __future__ import annotations import json +import os import re +from tempfile import TemporaryDirectory from unittest import mock from unittest.mock import patch @@ -37,6 +39,19 @@ CERT_AUTHENTICATION = "airflow.providers.trino.hooks.trino.trino.auth.CertificateAuthentication" +@pytest.fixture() +def jwt_token_file(): + # Couldn't get this working with TemporaryFile, using TemporaryDirectory instead + # Save a phony jwt to a temporary file for the trino hook to read from + with TemporaryDirectory() as tmp_dir: + tmp_jwt_file = os.path.join(tmp_dir, "jwt.json") + + with open(tmp_jwt_file, "w") as tmp_file: + tmp_file.write('{"phony":"jwt"}') + + yield tmp_jwt_file + + class TestTrinoHookConn: @patch(BASIC_AUTHENTICATION) @patch(TRINO_DBAPI_CONNECT) @@ -110,6 +125,21 @@ def test_get_conn_jwt_auth(self, mock_get_connection, mock_connect, mock_jwt_aut TrinoHook().get_conn() self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth) + @patch(JWT_AUTHENTICATION) + @patch(TRINO_DBAPI_CONNECT) + @patch(HOOK_GET_CONNECTION) + def test_get_conn_jwt_file(self, mock_get_connection, mock_connect, mock_jwt_auth, jwt_token_file): + extras = { + "auth": "jwt", + "jwt__file": jwt_token_file, + } + self.set_get_connection_return_value( + mock_get_connection, + extra=json.dumps(extras), + ) + TrinoHook().get_conn() + self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth) + @patch(CERT_AUTHENTICATION) @patch(TRINO_DBAPI_CONNECT) @patch(HOOK_GET_CONNECTION)