diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index fcb40f2199b5..b1f1e7c606cc 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -64,7 +64,7 @@ from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session from sqlalchemy.orm.mapper import Mapper from sqlalchemy.schema import UniqueConstraint -from sqlalchemy.sql import column, ColumnElement, literal_column, table, text +from sqlalchemy.sql import column, ColumnElement, literal_column, table from sqlalchemy.sql.elements import ColumnClause, TextClause from sqlalchemy.sql.expression import Label, Select, TextAsFrom from sqlalchemy.sql.selectable import Alias, TableClause @@ -102,6 +102,16 @@ VIRTUAL_TABLE_ALIAS = "virtual_table" +def text(clause: str) -> TextClause: + """ + SQLALchemy wrapper to ensure text clauses are escaped properly + + :param clause: clause potentially containing colons + :return: text clause with escaped colons + """ + return sa.text(clause.replace(":", "\\:")) + + class SqlaQuery(NamedTuple): applied_template_filters: List[str] extra_cache_keys: List[Any] @@ -793,7 +803,7 @@ def get_from_clause( raise QueryObjectValidationError( _("Virtual dataset query must be read-only") ) - return TextAsFrom(sa.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS) + return TextAsFrom(text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS) def get_rendered_sql( self, template_processor: Optional[BaseTemplateProcessor] = None @@ -814,8 +824,6 @@ def get_rendered_sql( ) ) from ex sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True) - # we need to escape strings that SQLAlchemy interprets as bind parameters - sql = utils.escape_sqla_query_binds(sql) if not sql: raise QueryObjectValidationError(_("Virtual dataset query cannot be empty")) if len(sqlparse.split(sql)) > 1: @@ -1265,7 +1273,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex - where_clause_and += [sa.text("({})".format(where))] + where_clause_and += [text(f"({where})")] having = extras.get("having") if having: try: @@ -1277,7 +1285,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex - having_clause_and += [sa.text("({})".format(having))] + having_clause_and += [text(f"({having})")] if apply_fetch_values_predicate and self.fetch_values_predicate: qry = qry.where(self.get_fetch_values_predicate()) if granularity: diff --git a/superset/utils/core.py b/superset/utils/core.py index 9f30c64d91f5..71d59348c100 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -80,7 +80,6 @@ from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.type_api import Variant from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine from typing_extensions import TypedDict, TypeGuard @@ -132,8 +131,6 @@ InputType = TypeVar("InputType") -BIND_PARAM_REGEX = TextClause._bind_params_regex # pylint: disable=protected-access - class LenientEnum(Enum): """Enums with a `get` method that convert a enum value to `Enum` if it is a @@ -1787,29 +1784,3 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int: if limit != 0: return min(max_limit, limit) return max_limit - - -def escape_sqla_query_binds(sql: str) -> str: - """ - Replace strings in a query that SQLAlchemy would otherwise interpret as - bind parameters. - - :param sql: unescaped query string - :return: escaped query string - >>> escape_sqla_query_binds("select ':foo'") - "select '\\\\:foo'" - >>> escape_sqla_query_binds("select 'foo'::TIMESTAMP") - "select 'foo'::TIMESTAMP" - >>> escape_sqla_query_binds("select ':foo :bar'::TIMESTAMP") - "select '\\\\:foo \\\\:bar'::TIMESTAMP" - >>> escape_sqla_query_binds("select ':foo :foo :bar'::TIMESTAMP") - "select '\\\\:foo \\\\:foo \\\\:bar'::TIMESTAMP" - """ - matches = BIND_PARAM_REGEX.finditer(sql) - processed_binds = set() - for match in matches: - bind = match.group(0) - if bind not in processed_binds: - sql = sql.replace(bind, bind.replace(":", "\\:")) - processed_binds.add(bind) - return sql diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 3647442eba18..fd228b6e3da6 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -20,7 +20,7 @@ import unittest from datetime import datetime from io import BytesIO -from typing import Optional +from typing import Optional, List from unittest import mock from zipfile import is_zipfile, ZipFile @@ -42,6 +42,7 @@ load_world_bank_dashboard_with_slices, ) from tests.integration_tests.test_app import app +from superset import security_manager from superset.charts.commands.data import ChartDataCommand from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.errors import SupersetErrorType @@ -57,6 +58,7 @@ ChartDataResultFormat, get_example_database, get_main_database, + AdhocMetricExpressionType, ) @@ -2031,3 +2033,58 @@ def test_chart_data_series_limit(self): self.assertEqual( set(column for column in data[0].keys()), {"state", "name", "sum__num"} ) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_chart_data_virtual_table_with_colons(self): + """ + Chart data API: test query with literal colon characters in query, metrics, + where clause and filters + """ + self.login(username="admin") + owner = self.get_user("admin").id + user = db.session.query(security_manager.user_model).get(owner) + + table = SqlaTable( + table_name="virtual_table_1", + schema=get_example_default_schema(), + owners=[user], + database=get_example_database(), + sql="select ':foo' as foo, ':bar:' as bar, state, num from birth_names", + ) + db.session.add(table) + db.session.commit() + table.fetch_metadata() + + request_payload = get_query_context("birth_names") + request_payload["datasource"] = { + "type": "table", + "id": table.id, + } + request_payload["queries"][0]["columns"] = ["foo", "bar", "state"] + request_payload["queries"][0]["where"] = "':abc' != ':xyz:qwerty'" + request_payload["queries"][0]["orderby"] = None + request_payload["queries"][0]["metrics"] = [ + { + "expressionType": AdhocMetricExpressionType.SQL, + "sqlExpression": "sum(case when state = ':asdf' then 0 else 1 end)", + "label": "count", + } + ] + request_payload["queries"][0]["filters"] = [ + {"col": "foo", "op": "!=", "val": ":qwerty:",} + ] + + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + db.session.delete(table) + db.session.commit() + assert rv.status_code == 200 + response_payload = json.loads(rv.data.decode("utf-8")) + result = response_payload["result"][0] + data = result["data"] + assert {col for col in data[0].keys()} == {"foo", "bar", "state", "count"} + # make sure results and query parameters are unescaped + assert {row["foo"] for row in data} == {":foo"} + assert {row["bar"] for row in data} == {":bar:"} + assert "':asdf'" in result["query"] + assert "':xyz:qwerty'" in result["query"] + assert "':qwerty:'" in result["query"]