Skip to content

Commit

Permalink
feat: add enforce URI query params with a specific for MySQL (#23723)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar authored and eschutho committed May 31, 2023
1 parent 70bdf40 commit 48ea6c0
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 4 deletions.
9 changes: 8 additions & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
top_keywords: Set[str] = {"TOP"}
# A set of disallowed connection query parameters
disallow_uri_query_params: Set[str] = set()
# A Dict of query parameters that will always be used on every connection
enforce_uri_query_params: Dict[str, Any] = {}

force_column_alias_quotes = False
arraysize = 0
Expand Down Expand Up @@ -1016,8 +1018,13 @@ def adjust_database_uri( # pylint: disable=unused-argument
Some database drivers like Presto accept '{catalog}/{schema}' in
the database component of the URL, that can be handled here.
Currently, changing the catalog is not supported. The method accepts a catalog so
that when catalog support is added to Superset the interface remains the same.
This is important because DB engine specs can be installed from 3rd party
packages.
"""
return uri
return uri, {**cls.enforce_uri_query_params}

@classmethod
def patch(cls) -> None:
Expand Down
6 changes: 5 additions & 1 deletion superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
),
}
disallow_uri_query_params = {"local_infile"}
enforce_uri_query_params = {"local_infile": 0}

@classmethod
def convert_dttm(
Expand All @@ -192,10 +193,13 @@ def convert_dttm(
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
uri, new_connect_args = super(
MySQLEngineSpec, MySQLEngineSpec
).adjust_database_uri(uri)
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))

return uri
return uri, new_connect_args

@classmethod
def get_datatype(cls, type_code: Any) -> Optional[str]:
Expand Down
15 changes: 15 additions & 0 deletions tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,21 @@ def test_impersonate_user_presto(self, mocked_create_engine):
"password": "original_user_password",
}

@unittest.skipUnless(
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
)
@mock.patch("superset.models.core.create_engine")
def test_adjust_engine_params_mysql(self, mocked_create_engine):
model = Database(
database_name="test_database",
sqlalchemy_uri="mysql://user:password@localhost",
)
model._get_sqla_engine()
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "mysql://user:password@localhost"
assert call_args[1]["connect_args"]["local_infile"] == 0

@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
principal_user = security_manager.find_user(username="gamma")
Expand Down
32 changes: 30 additions & 2 deletions tests/unit_tests/db_engine_specs/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

from datetime import datetime
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Tuple, Type
from unittest.mock import Mock, patch

import pytest
Expand All @@ -33,7 +33,7 @@
TINYINT,
TINYTEXT,
)
from sqlalchemy.engine.url import make_url
from sqlalchemy.engine.url import make_url, URL

from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
Expand Down Expand Up @@ -119,6 +119,34 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
MySQLEngineSpec.validate_database_uri(url)


@pytest.mark.parametrize(
"sqlalchemy_uri,connect_args,returns",
[
("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
(
"mysql://user:password@host/db1",
{"param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
(
"mysql://user:password@host/db1",
{"local_infile": 1, "param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
],
)
def test_adjust_database_uri(
sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any]
) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec

url = make_url(sqlalchemy_uri)
returned_url, returned_connect_args = MySQLEngineSpec.adjust_database_uri(url)
assert returned_connect_args == returns


@patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_id(engine_mock: Mock) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec
Expand Down

0 comments on commit 48ea6c0

Please sign in to comment.