Skip to content

Commit

Permalink
feat: trino support server-cert (#16346)
Browse files Browse the repository at this point in the history
Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
  • Loading branch information
dungdm93 committed Nov 26, 2021
1 parent ff68502 commit ebb3419
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
27 changes: 23 additions & 4 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from urllib import parse

import simplejson as json
Expand All @@ -24,6 +24,9 @@
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils

if TYPE_CHECKING:
from superset.models.core import Database


class TrinoEngineSpec(BaseEngineSpec):
engine = "trino"
Expand Down Expand Up @@ -81,7 +84,6 @@ def update_impersonation_config(
that can set the correct properties for impersonating users
:param connect_args: config to be updated
:param uri: URI string
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
:return: None
"""
Expand Down Expand Up @@ -116,9 +118,7 @@ def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
Run a SQL query that estimates the cost of a given statement.
:param statement: A single SQL statement
:param database: Database instance
:param cursor: Cursor instance
:param username: Effective username
:return: JSON response from Trino
"""
sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}"
Expand Down Expand Up @@ -183,3 +183,22 @@ def humanize(value: Any, suffix: str) -> str:
cost.append(statement_cost)

return cost

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
"""
Some databases require adding elements to connection parameters,
like passing certificates to `extra`. This can be done here.
:param database: database instance from which to extract extras
:raises CertificateException: If certificate is not valid/unparseable
"""
extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database)
engine_params: Dict[str, Any] = extra.setdefault("engine_params", {})
connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {})

if database.server_cert:
connect_args["http_scheme"] = "https"
connect_args["verify"] = utils.create_ssl_cert_file(database.server_cert)

return extra
35 changes: 35 additions & 0 deletions tests/integration_tests/db_engine_specs/trino_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from unittest.mock import Mock, patch

from sqlalchemy.engine.url import URL

from superset.db_engine_specs.trino import TrinoEngineSpec
Expand Down Expand Up @@ -52,3 +55,35 @@ def test_adjust_database_uri_when_selected_schema_is_none(self):
url.database = "hive/default"
TrinoEngineSpec.adjust_database_uri(url, selected_schema=None)
self.assertEqual(url.database, "hive/default")

def test_get_extra_params(self):
database = Mock()

database.extra = json.dumps({})
database.server_cert = None
extra = TrinoEngineSpec.get_extra_params(database)
expected = {"engine_params": {"connect_args": {}}}
self.assertEqual(extra, expected)

expected = {
"first": 1,
"engine_params": {"second": "two", "connect_args": {"third": "three"}},
}
database.extra = json.dumps(expected)
database.server_cert = None
extra = TrinoEngineSpec.get_extra_params(database)
self.assertEqual(extra, expected)

@patch("superset.utils.core.create_ssl_cert_file")
def test_get_extra_params_with_server_cert(self, create_ssl_cert_file_func: Mock):
database = Mock()

database.extra = json.dumps({})
database.server_cert = "TEST_CERT"
create_ssl_cert_file_func.return_value = "/path/to/tls.crt"
extra = TrinoEngineSpec.get_extra_params(database)

connect_args = extra.get("engine_params", {}).get("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
self.assertEqual(connect_args.get("verify"), "/path/to/tls.crt")
create_ssl_cert_file_func.assert_called_once_with(database.server_cert)

0 comments on commit ebb3419

Please sign in to comment.