From cd55b411a7e02c8f88bd2fe294edd14c33987645 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 12 Nov 2021 14:24:12 +0200 Subject: [PATCH] fix: avoid escaping bind-like params containing colons --- superset/connectors/sqla/models.py | 25 +++++---- superset/utils/core.py | 26 --------- tests/integration_tests/charts/api_tests.py | 59 ++++++++++++++++++++- 3 files changed, 74 insertions(+), 36 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 376751a8515b..8127ce598dd4 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -65,7 +65,7 @@ from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session from sqlalchemy.orm.mapper import Mapper from sqlalchemy.schema import UniqueConstraint -from sqlalchemy.sql import column, ColumnElement, literal_column, table, text +from sqlalchemy.sql import column, ColumnElement, literal_column, table from sqlalchemy.sql.elements import ColumnClause, TextClause from sqlalchemy.sql.expression import Label, Select, TextAsFrom from sqlalchemy.sql.selectable import Alias, TableClause @@ -103,6 +103,16 @@ VIRTUAL_TABLE_ALIAS = "virtual_table" +def text(clause: str) -> TextClause: + """ + SQLALchemy wrapper to ensure text clauses are escaped properly + + :param clause: clause potentially containing colons + :return: text clause with escaped colons + """ + return sa.text(clause.replace(":", "\\:")) + + class SqlaQuery(NamedTuple): applied_template_filters: List[str] extra_cache_keys: List[Any] @@ -806,7 +816,7 @@ def get_from_clause( raise QueryObjectValidationError( _("Virtual dataset query must be read-only") ) - return TextAsFrom(sa.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS) + return TextAsFrom(text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS) def get_rendered_sql( self, template_processor: Optional[BaseTemplateProcessor] = None @@ -827,8 +837,6 @@ def get_rendered_sql( ) ) from ex sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True) - # we need to escape strings that SQLAlchemy interprets as bind parameters - sql = utils.escape_sqla_query_binds(sql) if not sql: raise QueryObjectValidationError(_("Virtual dataset query cannot be empty")) if len(sqlparse.split(sql)) > 1: @@ -930,9 +938,8 @@ def _get_sqla_row_level_filters( filters_grouped: Dict[Union[int, str], List[str]] = defaultdict(list) try: for filter_ in security_manager.get_rls_filters(self): - clause = text( - f"({template_processor.process_template(filter_.clause)})" - ) + rendered_filter = template_processor.process_template(filter_.clause) + clause = text(f"({rendered_filter})") filters_grouped[filter_.group_key or filter_.id].append(clause) return [or_(*clauses) for clauses in filters_grouped.values()] except TemplateError as ex: @@ -1286,7 +1293,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex - where_clause_and += [sa.text("({})".format(where))] + where_clause_and += [text(f"({where})")] having = extras.get("having") if having: try: @@ -1298,7 +1305,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex - having_clause_and += [sa.text("({})".format(having))] + having_clause_and += [text(f"({having})")] if apply_fetch_values_predicate and self.fetch_values_predicate: qry = qry.where(self.get_fetch_values_predicate()) if granularity: diff --git a/superset/utils/core.py b/superset/utils/core.py index 625423e1e05c..ed9e40104a9e 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1771,29 +1771,3 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int: if limit != 0: return min(max_limit, limit) return max_limit - - -def escape_sqla_query_binds(sql: str) -> str: - """ - Replace strings in a query that SQLAlchemy would otherwise interpret as - bind parameters. - - :param sql: unescaped query string - :return: escaped query string - >>> escape_sqla_query_binds("select ':foo'") - "select '\\\\:foo'" - >>> escape_sqla_query_binds("select 'foo'::TIMESTAMP") - "select 'foo'::TIMESTAMP" - >>> escape_sqla_query_binds("select ':foo :bar'::TIMESTAMP") - "select '\\\\:foo \\\\:bar'::TIMESTAMP" - >>> escape_sqla_query_binds("select ':foo :foo :bar'::TIMESTAMP") - "select '\\\\:foo \\\\:foo \\\\:bar'::TIMESTAMP" - """ - matches = BIND_PARAM_REGEX.finditer(sql) - processed_binds = set() - for match in matches: - bind = match.group(0) - if bind not in processed_binds: - sql = sql.replace(bind, bind.replace(":", "\\:")) - processed_binds.add(bind) - return sql diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 696b52154bff..e9a87d45ffc8 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -20,7 +20,7 @@ import unittest from datetime import datetime from io import BytesIO -from typing import Optional +from typing import Optional, List from unittest import mock from zipfile import is_zipfile, ZipFile @@ -42,6 +42,7 @@ load_world_bank_dashboard_with_slices, ) from tests.integration_tests.test_app import app +from superset import security_manager from superset.charts.commands.data import ChartDataCommand from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.errors import SupersetErrorType @@ -56,6 +57,7 @@ get_example_database, get_example_default_schema, get_main_database, + AdhocMetricExpressionType, ) from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType @@ -2033,3 +2035,58 @@ def test_chart_data_series_limit(self): self.assertEqual( set(column for column in data[0].keys()), {"state", "name", "sum__num"} ) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_chart_data_virtual_table_with_colons(self): + """ + Chart data API: test query with literal colon characters in query, metrics, + where clause and filters + """ + self.login(username="admin") + owner = self.get_user("admin").id + user = db.session.query(security_manager.user_model).get(owner) + + table = SqlaTable( + table_name="virtual_table_1", + schema=get_example_default_schema(), + owners=[user], + database=get_example_database(), + sql="select ':foo' as foo, ':bar:' as bar, * from birth_names", + ) + db.session.add(table) + db.session.commit() + table.fetch_metadata() + + request_payload = get_query_context("birth_names") + request_payload["datasource"] = { + "type": "table", + "id": table.id, + } + request_payload["queries"][0]["columns"] = ["foo", "bar", "state"] + request_payload["queries"][0]["where"] = "':abc' != ':xyz:qwerty'" + request_payload["queries"][0]["orderby"] = None + request_payload["queries"][0]["metrics"] = [ + { + "expressionType": AdhocMetricExpressionType.SQL, + "sqlExpression": "sum(case when state = ':asdf' then 0 else 1 end)", + "label": "count", + } + ] + request_payload["queries"][0]["filters"] = [ + {"col": "foo", "op": "!=", "val": ":qwerty:",} + ] + + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + db.session.delete(table) + db.session.commit() + self.assertEqual(rv.status_code, 200) + response_payload = json.loads(rv.data.decode("utf-8")) + result = response_payload["result"][0] + data = result["data"] + assert {col for col in data[0].keys()} == {"foo", "bar", "state", "count"} + # make sure results and query parameters are unescaped + assert {row["foo"] for row in data} == {":foo"} + assert {row["bar"] for row in data} == {":bar:"} + assert "':asdf'" in result["query"] + assert "':xyz:qwerty'" in result["query"] + assert "':qwerty:'" in result["query"]