Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add enforce URI query params with a specific for MySQL #23723

Merged
merged 3 commits into from
Apr 18, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 5 additions & 3 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
top_keywords: Set[str] = {"TOP"}
# A set of disallowed connection query parameters
disallow_uri_query_params: Set[str] = set()
# A Dict of query parameters that will always be used on every connection
enforce_uri_query_params: Dict[str, Any] = {}

force_column_alias_quotes = False
arraysize = 0
Expand Down Expand Up @@ -1089,11 +1091,11 @@ def adjust_engine_params( # pylint: disable=unused-argument
``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
given query is running in order to enforce permissions (see #23385 and #23401).

Currently, changing the catalog is not supported. The method acceps a catalog so
that when catalog support is added to Superse the interface remains the same. This
Currently, changing the catalog is not supported. The method accepts a catalog so
that when catalog support is added to Superset the interface remains the same. This
is important because DB engine specs can be installed from 3rd party packages.
"""
return uri, connect_args
return uri, {**connect_args, **cls.enforce_uri_query_params}

@classmethod
def patch(cls) -> None:
Expand Down
6 changes: 5 additions & 1 deletion superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
),
}
disallow_uri_query_params = {"local_infile"}
enforce_uri_query_params = {"local_infile": 0}

@classmethod
def convert_dttm(
Expand All @@ -198,10 +199,13 @@ def adjust_engine_params(
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
uri, new_connect_args = super(
MySQLEngineSpec, MySQLEngineSpec
).adjust_engine_params(uri, connect_args, catalog, schema)
if schema:
uri = uri.set(database=parse.quote(schema, safe=""))

return uri, connect_args
return uri, new_connect_args

@classmethod
def get_schema_from_engine_params(
Expand Down
15 changes: 15 additions & 0 deletions tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,21 @@ def test_impersonate_user_presto(self, mocked_create_engine):
"password": "original_user_password",
}

@unittest.skipUnless(
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
)
@mock.patch("superset.models.core.create_engine")
def test_adjust_engine_params_mysql(self, mocked_create_engine):
model = Database(
database_name="test_database",
sqlalchemy_uri="mysql://user:password@localhost",
)
model._get_sqla_engine()
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "mysql://user:password@localhost"
assert call_args[1]["connect_args"]["local_infile"] == 0

@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
principal_user = security_manager.find_user(username="gamma")
Expand Down
34 changes: 32 additions & 2 deletions tests/unit_tests/db_engine_specs/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

from datetime import datetime
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Tuple, Type
from unittest.mock import Mock, patch

import pytest
Expand All @@ -33,7 +33,7 @@
TINYINT,
TINYTEXT,
)
from sqlalchemy.engine.url import make_url
from sqlalchemy.engine.url import make_url, URL

from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
Expand Down Expand Up @@ -119,6 +119,36 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
MySQLEngineSpec.validate_database_uri(url)


@pytest.mark.parametrize(
"sqlalchemy_uri,connect_args,returns",
[
("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
(
"mysql://user:password@host/db1",
{"param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
(
"mysql://user:password@host/db1",
{"local_infile": 1, "param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
],
)
def test_adjust_engine_params(
sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any]
) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec

url = make_url(sqlalchemy_uri)
returned_url, returned_connect_args = MySQLEngineSpec.adjust_engine_params(
url, connect_args
)
assert returned_connect_args == returns


@patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_id(engine_mock: Mock) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec
Expand Down