diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 63ee7d7f9dcd..a3647034dcf6 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -354,6 +354,8 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     # This set will give the keywords for data limit statements
     # to consider for the engines with TOP SQL parsing
     top_keywords: Set[str] = {"TOP"}
+    # A set of disallowed connection query parameters
+    disallow_uri_query_params: Set[str] = set()
 
     force_column_alias_quotes = False
     arraysize = 0
@@ -1724,6 +1726,21 @@ def get_public_information(cls) -> Dict[str, Any]:
             "disable_ssh_tunneling": cls.disable_ssh_tunneling,
         }
 
+    @classmethod
+    def validate_database_uri(cls, sqlalchemy_uri: URL) -> None:
+        """
+        Validates a database SQLAlchemy URI per engine spec.
+        Use this to implement a final validation for unwanted connection configuration
+
+        :param sqlalchemy_uri: the parsed SQLAlchemy URL to validate
+        :raises ValueError: if the URL's query string contains any parameter listed
+            in ``disallow_uri_query_params``
+        """
+        if existing_disallowed := cls.disallow_uri_query_params.intersection(
+            sqlalchemy_uri.query
+        ):
+            raise ValueError(f"Forbidden query parameter(s): {existing_disallowed}")
+
 
 # schema for adding a database by providing parameters instead of the
 # full SQLAlchemy URI
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index b873daff7560..348b3287e35d 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -173,6 +173,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
             {},
         ),
     }
+    disallow_uri_query_params = {"local_infile"}
 
     @classmethod
     def convert_dttm(
diff --git a/superset/models/core.py b/superset/models/core.py
index ac7cc517ef0d..9c67a2efa6d2 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -424,6 +424,8 @@ def _get_sqla_engine(
         sqlalchemy_url = make_url_safe(
             sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
         )
+        self.db_engine_spec.validate_database_uri(sqlalchemy_url)
+
         sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
         effective_username = self.get_effective_user(sqlalchemy_url)
         # If using MySQL or Presto for example, will set url.username
diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py
index 4562e497c6e6..a512e71a97f6 100644
--- a/tests/unit_tests/db_engine_specs/test_mysql.py
+++ b/tests/unit_tests/db_engine_specs/test_mysql.py
@@ -33,6 +33,7 @@
     TINYINT,
     TINYTEXT,
 )
+from sqlalchemy.engine.url import make_url
 
 from superset.utils.core import GenericDataType
 from tests.unit_tests.db_engine_specs.utils import (
@@ -99,6 +100,26 @@ def test_convert_dttm(
     assert_convert_dttm(spec, target_type, expected_result, dttm)
 
 
+@pytest.mark.parametrize(
+    "sqlalchemy_uri,error",
+    [
+        ("mysql://user:password@host/db1?local_infile=1", True),
+        ("mysql://user:password@host/db1?local_infile=0", True),
+        ("mysql://user:password@host/db1?charset=utf8", False),
+        ("mysql://user:password@host/db1", False),
+    ],
+)
+def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
+    from superset.db_engine_specs.mysql import MySQLEngineSpec
+
+    url = make_url(sqlalchemy_uri)
+    if error:
+        with pytest.raises(ValueError):
+            MySQLEngineSpec.validate_database_uri(url)
+        return
+    MySQLEngineSpec.validate_database_uri(url)
+
+
 @patch("sqlalchemy.engine.Engine.connect")
 def test_get_cancel_query_id(engine_mock: Mock) -> None:
     from superset.db_engine_specs.mysql import MySQLEngineSpec