Skip to content

Commit

Permalink
Trino Hook: Add ability to read JWT from file (#31950)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Joshua H. Bigler <joshua.bigler@pnnl.gov>
Co-authored-by: Joshua Bigler <joshuab@joshuabcentos8.pnl.gov>
Co-authored-by: Phani Kumar <94376113+phanikumv@users.noreply.github.com>
  • Loading branch information
4 people committed Jun 24, 2023
1 parent 9a35a40 commit 371833e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
7 changes: 6 additions & 1 deletion airflow/providers/trino/hooks/trino.py
Expand Up @@ -95,7 +95,12 @@ def get_conn(self) -> Connection:
elif db.password:
auth = trino.auth.BasicAuthentication(db.login, db.password) # type: ignore[attr-defined]
elif extra.get("auth") == "jwt":
auth = trino.auth.JWTAuthentication(token=extra.get("jwt__token"))
if "jwt__file" in extra:
with open(extra.get("jwt__file")) as jwt_file:
token = jwt_file.read()
else:
token = extra.get("jwt__token")
auth = trino.auth.JWTAuthentication(token=token)
elif extra.get("auth") == "certs":
auth = trino.auth.CertificateAuthentication(
extra.get("certs__client_cert_path"),
Expand Down
3 changes: 3 additions & 0 deletions docs/apache-airflow-providers-trino/connections.rst
Expand Up @@ -49,7 +49,10 @@ Extra (optional, connection parameters)
The following extra parameters can be used to configure authentication:

* ``jwt__token`` - If jwt authentication should be used, the value of token is given via this parameter.
* ``jwt__file`` - If jwt authentication should be used, the location on disk for the file containing the jwt token.
* ``certs__client_cert_path``, ``certs__client_key_path``- If certificate authentication should be used, the path to the client certificate and key is given via these parameters.
* ``kerberos__service_name``, ``kerberos__config``, ``kerberos__mutual_authentication``, ``kerberos__force_preemptive``, ``kerberos__hostname_override``, ``kerberos__sanitize_mutual_error_response``, ``kerberos__principal``,``kerberos__delegate``, ``kerberos__ca_bundle`` - These parameters can be set when enabling ``kerberos`` authentication.
* ``session_properties`` - JSON dictionary which allows to set session_properties. Example: ``{'session_properties':{'scale_writers':true,'task_writer_count:1'}}``
* ``client_tags`` - List of comma separated tags. Example ``{'client_tags':['sales','cluster1']}```

Note: If ``jwt__file`` and ``jwt__token`` are both given, ``jwt__file`` will take precedent.
30 changes: 30 additions & 0 deletions tests/providers/trino/hooks/test_trino.py
Expand Up @@ -18,7 +18,9 @@
from __future__ import annotations

import json
import os
import re
from tempfile import TemporaryDirectory
from unittest import mock
from unittest.mock import patch

Expand All @@ -37,6 +39,19 @@
CERT_AUTHENTICATION = "airflow.providers.trino.hooks.trino.trino.auth.CertificateAuthentication"


@pytest.fixture()
def jwt_token_file():
# Couldn't get this working with TemporaryFile, using TemporaryDirectory instead
# Save a phony jwt to a temporary file for the trino hook to read from
with TemporaryDirectory() as tmp_dir:
tmp_jwt_file = os.path.join(tmp_dir, "jwt.json")

with open(tmp_jwt_file, "w") as tmp_file:
tmp_file.write('{"phony":"jwt"}')

yield tmp_jwt_file


class TestTrinoHookConn:
@patch(BASIC_AUTHENTICATION)
@patch(TRINO_DBAPI_CONNECT)
Expand Down Expand Up @@ -110,6 +125,21 @@ def test_get_conn_jwt_auth(self, mock_get_connection, mock_connect, mock_jwt_aut
TrinoHook().get_conn()
self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth)

@patch(JWT_AUTHENTICATION)
@patch(TRINO_DBAPI_CONNECT)
@patch(HOOK_GET_CONNECTION)
def test_get_conn_jwt_file(self, mock_get_connection, mock_connect, mock_jwt_auth, jwt_token_file):
extras = {
"auth": "jwt",
"jwt__file": jwt_token_file,
}
self.set_get_connection_return_value(
mock_get_connection,
extra=json.dumps(extras),
)
TrinoHook().get_conn()
self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth)

@patch(CERT_AUTHENTICATION)
@patch(TRINO_DBAPI_CONNECT)
@patch(HOOK_GET_CONNECTION)
Expand Down

0 comments on commit 371833e

Please sign in to comment.