Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid escaping bind-like params containing colons #17419

Merged
merged 3 commits into from
Nov 13, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
from sqlalchemy.sql import column, ColumnElement, literal_column, table
from sqlalchemy.sql.elements import ColumnClause, TextClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause
Expand Down Expand Up @@ -103,6 +103,16 @@
VIRTUAL_TABLE_ALIAS = "virtual_table"


def text(clause: str) -> TextClause:
    """
    Build a SQLAlchemy text clause from a raw SQL fragment, escaping colons.

    Every ``:`` in the fragment is prefixed with a backslash before the
    string is handed to ``sa.text``, so that colon-prefixed words inside the
    fragment are not interpreted as bind parameters.

    :param clause: clause potentially containing colons
    :return: text clause with escaped colons
    """
    escaped_clause = clause.replace(":", "\\:")
    return sa.text(escaped_clause)


class SqlaQuery(NamedTuple):
applied_template_filters: List[str]
extra_cache_keys: List[Any]
Expand Down Expand Up @@ -806,7 +816,7 @@ def get_from_clause(
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)
return TextAsFrom(sa.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
return TextAsFrom(text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)

def get_rendered_sql(
self, template_processor: Optional[BaseTemplateProcessor] = None
Expand All @@ -827,8 +837,6 @@ def get_rendered_sql(
)
) from ex
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
# we need to escape strings that SQLAlchemy interprets as bind parameters
sql = utils.escape_sqla_query_binds(sql)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:
Expand Down Expand Up @@ -930,9 +938,8 @@ def _get_sqla_row_level_filters(
filters_grouped: Dict[Union[int, str], List[str]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self):
clause = text(
f"({template_processor.process_template(filter_.clause)})"
)
rendered_filter = template_processor.process_template(filter_.clause)
clause = text(f"({rendered_filter})")
villebro marked this conversation as resolved.
Show resolved Hide resolved
filters_grouped[filter_.group_key or filter_.id].append(clause)
return [or_(*clauses) for clauses in filters_grouped.values()]
except TemplateError as ex:
Expand Down Expand Up @@ -1286,7 +1293,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
msg=ex.message,
)
) from ex
where_clause_and += [sa.text("({})".format(where))]
where_clause_and += [text(f"({where})")]
having = extras.get("having")
if having:
try:
Expand All @@ -1298,7 +1305,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
msg=ex.message,
)
) from ex
having_clause_and += [sa.text("({})".format(having))]
having_clause_and += [text(f"({having})")]
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
if granularity:
Expand Down
26 changes: 0 additions & 26 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,29 +1771,3 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
if limit != 0:
return min(max_limit, limit)
return max_limit


def escape_sqla_query_binds(sql: str) -> str:
    """
    Replace strings in a query that SQLAlchemy would otherwise interpret as
    bind parameters.

    Each occurrence found by ``BIND_PARAM_REGEX`` is escaped exactly once, at
    the position where it was matched. Escaping positionally (instead of via
    a global ``str.replace`` per distinct match) avoids double-escaping when
    one bind-like token is a prefix of another (e.g. ``:foo`` and
    ``:foobar``), and leaves text that the regex did not match untouched.

    :param sql: unescaped query string
    :return: escaped query string
    >>> escape_sqla_query_binds("select ':foo'")
    "select '\\\\:foo'"
    >>> escape_sqla_query_binds("select 'foo'::TIMESTAMP")
    "select 'foo'::TIMESTAMP"
    >>> escape_sqla_query_binds("select ':foo :bar'::TIMESTAMP")
    "select '\\\\:foo \\\\:bar'::TIMESTAMP"
    >>> escape_sqla_query_binds("select ':foo :foo :bar'::TIMESTAMP")
    "select '\\\\:foo \\\\:foo \\\\:bar'::TIMESTAMP"
    """
    # A single regex substitution pass rewrites only the matched spans, so no
    # already-escaped text is ever visited a second time.
    return BIND_PARAM_REGEX.sub(
        lambda match: match.group(0).replace(":", "\\:"), sql
    )
59 changes: 58 additions & 1 deletion tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import unittest
from datetime import datetime
from io import BytesIO
from typing import Optional
from typing import Optional, List
from unittest import mock
from zipfile import is_zipfile, ZipFile

Expand All @@ -42,6 +42,7 @@
load_world_bank_dashboard_with_slices,
)
from tests.integration_tests.test_app import app
from superset import security_manager
from superset.charts.commands.data import ChartDataCommand
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.errors import SupersetErrorType
Expand All @@ -56,6 +57,7 @@
get_example_database,
get_example_default_schema,
get_main_database,
AdhocMetricExpressionType,
)
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType

Expand Down Expand Up @@ -2033,3 +2035,58 @@ def test_chart_data_series_limit(self):
self.assertEqual(
set(column for column in data[0].keys()), {"state", "name", "sum__num"}
)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_virtual_table_with_colons(self):
    """
    Chart data API: test query with literal colon characters in query, metrics,
    where clause and filters
    """
    self.login(username="admin")
    owner = self.get_user("admin").id
    user = db.session.query(security_manager.user_model).get(owner)

    # Virtual dataset whose SQL contains colon-prefixed string literals.
    table = SqlaTable(
        table_name="virtual_table_1",
        schema=get_example_default_schema(),
        owners=[user],
        database=get_example_database(),
        sql="select ':foo' as foo, ':bar:' as bar, state, num from birth_names",
    )
    db.session.add(table)
    db.session.commit()

    try:
        table.fetch_metadata()

        request_payload = get_query_context("birth_names")
        request_payload["datasource"] = {
            "type": "table",
            "id": table.id,
        }
        request_payload["queries"][0]["columns"] = ["foo", "bar", "state"]
        request_payload["queries"][0]["where"] = "':abc' != ':xyz:qwerty'"
        request_payload["queries"][0]["orderby"] = None
        request_payload["queries"][0]["metrics"] = [
            {
                "expressionType": AdhocMetricExpressionType.SQL,
                "sqlExpression": "sum(case when state = ':asdf' then 0 else 1 end)",
                "label": "count",
            }
        ]
        request_payload["queries"][0]["filters"] = [
            {"col": "foo", "op": "!=", "val": ":qwerty:"}
        ]

        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
    finally:
        # Always remove the temporary dataset, even when metadata fetching or
        # the request itself raises, so it does not linger in the database.
        db.session.delete(table)
        db.session.commit()

    self.assertEqual(rv.status_code, 200)
    response_payload = json.loads(rv.data.decode("utf-8"))
    result = response_payload["result"][0]
    data = result["data"]
    assert set(data[0].keys()) == {"foo", "bar", "state", "count"}
    # make sure results and query parameters are unescaped
    assert {row["foo"] for row in data} == {":foo"}
    assert {row["bar"] for row in data} == {":bar:"}
    assert "':asdf'" in result["query"]
    assert "':xyz:qwerty'" in result["query"]
    assert "':qwerty:'" in result["query"]