diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 8a56e8fb0ad3..8d294bfa89e4 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, TYPE_CHECKING from urllib import parse import simplejson as json @@ -24,6 +24,9 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.utils import core as utils +if TYPE_CHECKING: + from superset.models.core import Database + class TrinoEngineSpec(BaseEngineSpec): engine = "trino" @@ -81,7 +84,6 @@ def update_impersonation_config( that can set the correct properties for impersonating users :param connect_args: config to be updated :param uri: URI string - :param impersonate_user: Flag indicating if impersonation is enabled :param username: Effective username :return: None """ @@ -116,9 +118,7 @@ def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: Run a SQL query that estimates the cost of a given statement. :param statement: A single SQL statement - :param database: Database instance :param cursor: Cursor instance - :param username: Effective username :return: JSON response from Trino """ sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}" @@ -183,3 +183,22 @@ def humanize(value: Any, suffix: str) -> str: cost.append(statement_cost) return cost + + @staticmethod + def get_extra_params(database: "Database") -> Dict[str, Any]: + """ + Some databases require adding elements to connection parameters, + like passing certificates to `extra`. This can be done here. + + :param database: database instance from which to extract extras + :raises CertificateException: If certificate is not valid/unparseable + """ + extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database) + engine_params: Dict[str, Any] = extra.setdefault("engine_params", {}) + connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {}) + + if database.server_cert: + connect_args["http_scheme"] = "https" + connect_args["verify"] = utils.create_ssl_cert_file(database.server_cert) + + return extra diff --git a/tests/integration_tests/db_engine_specs/trino_tests.py b/tests/integration_tests/db_engine_specs/trino_tests.py index 211d00405ecd..e77e91603540 100644 --- a/tests/integration_tests/db_engine_specs/trino_tests.py +++ b/tests/integration_tests/db_engine_specs/trino_tests.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import json +from unittest.mock import Mock, patch + from sqlalchemy.engine.url import URL from superset.db_engine_specs.trino import TrinoEngineSpec @@ -52,3 +55,35 @@ def test_adjust_database_uri_when_selected_schema_is_none(self): url.database = "hive/default" TrinoEngineSpec.adjust_database_uri(url, selected_schema=None) self.assertEqual(url.database, "hive/default") + + def test_get_extra_params(self): + database = Mock() + + database.extra = json.dumps({}) + database.server_cert = None + extra = TrinoEngineSpec.get_extra_params(database) + expected = {"engine_params": {"connect_args": {}}} + self.assertEqual(extra, expected) + + expected = { + "first": 1, + "engine_params": {"second": "two", "connect_args": {"third": "three"}}, + } + database.extra = json.dumps(expected) + database.server_cert = None + extra = TrinoEngineSpec.get_extra_params(database) + self.assertEqual(extra, expected) + + @patch("superset.utils.core.create_ssl_cert_file") + def test_get_extra_params_with_server_cert(self, create_ssl_cert_file_func: Mock): + database = Mock() + + database.extra = json.dumps({}) + database.server_cert = "TEST_CERT" + create_ssl_cert_file_func.return_value = "/path/to/tls.crt" + extra = TrinoEngineSpec.get_extra_params(database) + + connect_args = extra.get("engine_params", {}).get("connect_args", {}) + self.assertEqual(connect_args.get("http_scheme"), "https") + self.assertEqual(connect_args.get("verify"), "/path/to/tls.crt") + create_ssl_cert_file_func.assert_called_once_with(database.server_cert)