Skip to content

Commit

Permalink
feat: add support for comments in adhoc clauses (#19248)
Browse files Browse the repository at this point in the history
* feat: add support for comments in adhoc clauses

* sanitize remaining freeform clauses

* sanitize adhoc having in frontend

* address review comment

(cherry picked from commit f341025)
  • Loading branch information
villebro committed Apr 3, 2022
1 parent f6346d6 commit 408573d
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 60 deletions.
Expand Up @@ -23,6 +23,14 @@ import { QueryObjectFilterClause } from './types/Query';
import { isSimpleAdhocFilter } from './types/Filter';
import convertFilter from './convertFilter';

/**
 * Wrap a free-form SQL clause in parentheses so it can be AND-joined with
 * other clauses.
 *
 * When the clause contains `--`, a newline is appended before the closing
 * parenthesis so that the parenthesis lands on its own line rather than on
 * the same line as the `--` text.
 */
function sanitizeClause(clause: string): string {
  const body = clause.includes('--') ? `${clause}\n` : clause;
  return `(${body})`;
}

/** Logic formerly in viz.py's process_query_filters */
export default function processFilters(
formData: Partial<QueryFormData>,
Expand Down Expand Up @@ -60,9 +68,9 @@ export default function processFilters(
});

// some filter-related fields need to go in `extras`
extras.having = freeformHaving.map(exp => `(${exp})`).join(' AND ');
extras.having = freeformHaving.map(sanitizeClause).join(' AND ');
extras.having_druid = simpleHaving;
extras.where = freeformWhere.map(exp => `(${exp})`).join(' AND ');
extras.where = freeformWhere.map(sanitizeClause).join(' AND ');

return {
filters: simpleWhere,
Expand Down
Expand Up @@ -132,12 +132,12 @@ describe('processFilters', () => {
{
expressionType: 'SQL',
clause: 'WHERE',
sqlExpression: 'tea = "jasmine"',
sqlExpression: "tea = 'jasmine'",
},
{
expressionType: 'SQL',
clause: 'WHERE',
sqlExpression: 'cup = "large"',
sqlExpression: "cup = 'large' -- comment",
},
{
expressionType: 'SQL',
Expand All @@ -147,13 +147,13 @@ describe('processFilters', () => {
{
expressionType: 'SQL',
clause: 'HAVING',
sqlExpression: 'waitTime <= 180',
sqlExpression: 'waitTime <= 180 -- comment',
},
],
}),
).toEqual({
extras: {
having: '(ice = 25 OR ice = 50) AND (waitTime <= 180)',
having: '(ice = 25 OR ice = 50) AND (waitTime <= 180 -- comment\n)',
having_druid: [
{
col: 'sweetness',
Expand All @@ -166,7 +166,7 @@ describe('processFilters', () => {
val: '50',
},
],
where: '(tea = "jasmine") AND (cup = "large")',
where: "(tea = 'jasmine') AND (cup = 'large' -- comment\n)",
},
filters: [
{
Expand Down
10 changes: 6 additions & 4 deletions superset/common/query_object.py
Expand Up @@ -30,7 +30,7 @@
QueryClauseValidationException,
QueryObjectValidationError,
)
from superset.sql_parse import validate_filter_clause
from superset.sql_parse import sanitize_clause
from superset.superset_typing import Column, Metric, OrderBy
from superset.utils import pandas_postprocessing
from superset.utils.core import (
Expand Down Expand Up @@ -272,7 +272,7 @@ def validate(
try:
self._validate_there_are_no_missing_series()
self._validate_no_have_duplicate_labels()
self._validate_filters()
self._sanitize_filters()
return None
except QueryObjectValidationError as ex:
if raise_exceptions:
Expand All @@ -291,12 +291,14 @@ def _validate_no_have_duplicate_labels(self) -> None:
)
)

def _validate_filters(self) -> None:
def _sanitize_filters(self) -> None:
# Pass the "where" and "having" entries of self.extras through the clause
# checker; empty or missing entries are skipped.
# Raises QueryObjectValidationError (carrying the original message) when the
# checker raises QueryClauseValidationException.
for param in ("where", "having"):
clause = self.extras.get(param)
if clause:
try:
validate_filter_clause(clause)
sanitized_clause = sanitize_clause(clause)
# Only write back when sanitizing actually changed the text.
if sanitized_clause != clause:
self.extras[param] = sanitized_clause
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex

Expand Down
23 changes: 17 additions & 6 deletions superset/connectors/sqla/models.py
Expand Up @@ -82,7 +82,10 @@
)
from superset.datasets.models import Dataset as NewDataset
from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import (
QueryClauseValidationException,
QueryObjectValidationError,
)
from superset.jinja_context import (
BaseTemplateProcessor,
ExtraCache,
Expand All @@ -96,7 +99,7 @@
clone_model,
QueryResult,
)
from superset.sql_parse import ParsedQuery
from superset.sql_parse import ParsedQuery, sanitize_clause
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
Expand Down Expand Up @@ -918,6 +921,10 @@ def adhoc_metric_to_sqla(
tp = self.get_template_processor()
expression = tp.process_template(cast(str, metric["sqlExpression"]))
validate_adhoc_subquery(expression)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
sqla_metric = literal_column(expression)
else:
raise QueryObjectValidationError("Adhoc metric expressionType is invalid")
Expand All @@ -943,6 +950,10 @@ def adhoc_column_to_sqla(
expression = template_processor.process_template(expression)
if expression:
validate_adhoc_subquery(expression)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
sqla_metric = literal_column(expression)
return self.make_sqla_column_compatible(sqla_metric, label)

Expand Down Expand Up @@ -1388,27 +1399,27 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
where = extras.get("where")
if where:
try:
where = template_processor.process_template(where)
where = template_processor.process_template(f"({where})")
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in WHERE clause: %(msg)s",
msg=ex.message,
)
) from ex
where_clause_and += [self.text(f"({where})")]
where_clause_and += [self.text(where)]
having = extras.get("having")
if having:
try:
having = template_processor.process_template(having)
having = template_processor.process_template(f"({having})")
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in HAVING clause: %(msg)s",
msg=ex.message,
)
) from ex
having_clause_and += [self.text(f"({having})")]
having_clause_and += [self.text(having)]
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
if granularity:
Expand Down
21 changes: 16 additions & 5 deletions superset/sql_parse.py
Expand Up @@ -32,6 +32,7 @@
Where,
)
from sqlparse.tokens import (
Comment,
CTE,
DDL,
DML,
Expand Down Expand Up @@ -441,25 +442,35 @@ def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
return str_res


def validate_filter_clause(clause: str) -> None:
if sqlparse.format(clause, strip_comments=True) != sqlparse.format(clause):
raise QueryClauseValidationException("Filter clause contains comment")

def sanitize_clause(clause: str) -> str:
# Validate a free-form SQL clause and return it, with a newline appended when
# the last token seen is a comment that does not already end in one.
# Raises QueryClauseValidationException when sqlparse does not return exactly
# one statement, when "*" "/" or "/" "*" appear as consecutive tokens, or when
# "(" and ")" tokens do not balance.
# clause = sqlparse.format(clause, strip_comments=True)
statements = sqlparse.parse(clause)
if len(statements) != 1:
raise QueryClauseValidationException("Filter clause contains multiple queries")
raise QueryClauseValidationException("Clause contains multiple statements")
# Running count of "(" minus ")" tokens seen so far.
open_parens = 0

previous_token = None
for token in statements[0]:
# "*" immediately followed by "/" is treated as a stray comment terminator.
if token.value == "/" and previous_token and previous_token.value == "*":
raise QueryClauseValidationException("Closing unopened multiline comment")
# "/" immediately followed by "*" is treated as a comment that never closes.
if token.value == "*" and previous_token and previous_token.value == "/":
raise QueryClauseValidationException("Unclosed multiline comment")
if token.value in (")", "("):
open_parens += 1 if token.value == "(" else -1
# A negative count means a ")" arrived before its matching "(".
if open_parens < 0:
raise QueryClauseValidationException(
"Closing unclosed parenthesis in filter clause"
)
previous_token = token
if open_parens > 0:
raise QueryClauseValidationException("Unclosed parenthesis in filter clause")

# After the loop previous_token is the final token iterated (None if there
# were no tokens); a trailing comment gets a terminating newline.
if previous_token and previous_token.ttype in Comment:
if previous_token.value[-1] != "\n":
clause = f"{clause}\n"

return clause


class InsertRLSState(str, Enum):
"""
Expand Down
7 changes: 5 additions & 2 deletions superset/utils/core.py
Expand Up @@ -98,6 +98,7 @@
SupersetException,
SupersetTimeoutException,
)
from superset.sql_parse import sanitize_clause
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
Expand Down Expand Up @@ -1382,10 +1383,12 @@ def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name
}
)
elif expression_type == "SQL":
sql_expression = adhoc_filter.get("sqlExpression")
sql_expression = sanitize_clause(sql_expression)
if clause == "WHERE":
sql_where_filters.append(adhoc_filter.get("sqlExpression"))
sql_where_filters.append(sql_expression)
elif clause == "HAVING":
sql_having_filters.append(adhoc_filter.get("sqlExpression"))
sql_having_filters.append(sql_expression)
form_data["where"] = " AND ".join(
["({})".format(sql) for sql in sql_where_filters]
)
Expand Down
10 changes: 4 additions & 6 deletions superset/viz.py
Expand Up @@ -62,14 +62,13 @@
from superset.exceptions import (
CacheLoadError,
NullValueException,
QueryClauseValidationException,
QueryObjectValidationError,
SpatialException,
SupersetSecurityException,
)
from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.sql_parse import validate_filter_clause
from superset.sql_parse import sanitize_clause
from superset.superset_typing import (
Column,
Metric,
Expand Down Expand Up @@ -391,10 +390,9 @@ def query_obj(self) -> QueryObjectDict: # pylint: disable=too-many-locals
for param in ("where", "having"):
clause = self.form_data.get(param)
if clause:
try:
validate_filter_clause(clause)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
sanitized_clause = sanitize_clause(clause)
if sanitized_clause != clause:
self.form_data[param] = sanitized_clause

# extras are used to query elements specific to a datasource type
# for instance the extra where clause that applies only to Tables
Expand Down
22 changes: 22 additions & 0 deletions tests/integration_tests/charts/data/api_tests.py
Expand Up @@ -465,6 +465,28 @@ def test_with_invalid_where_parameter_closing_unclosed__400(self):

assert rv.status_code == 400

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_where_parameter_including_comment___200(self):
    """A WHERE extra that ends in a ``--`` comment must yield HTTP 200."""
    query = self.query_context_payload["queries"][0]
    query["filters"] = []
    query["extras"]["where"] = "1 = 1 -- abc"

    response = self.post_assert_metric(
        CHART_DATA_URI, self.query_context_payload, "data"
    )

    assert response.status_code == 200

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_orderby_parameter_with_second_query__400(self):
    """An adhoc SQL orderby that carries a second statement must yield HTTP 400."""
    two_statement_orderby = {
        "expressionType": "SQL",
        "sqlExpression": "sum__num; select 1, 1",
    }
    query = self.query_context_payload["queries"][0]
    query["filters"] = []
    query["orderby"] = [[two_statement_orderby, True]]

    response = self.post_assert_metric(
        CHART_DATA_URI, self.query_context_payload, "data"
    )

    assert response.status_code == 400

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_invalid_having_parameter_closing_and_comment__400(self):
self.query_context_payload["queries"][0]["filters"] = []
Expand Down
54 changes: 24 additions & 30 deletions tests/unit_tests/sql_parse_tests.py
Expand Up @@ -30,9 +30,9 @@
insert_rls,
matches_table_name,
ParsedQuery,
sanitize_clause,
strip_comments_from_sql,
Table,
validate_filter_clause,
)


Expand Down Expand Up @@ -1142,52 +1142,46 @@ def test_strip_comments_from_sql() -> None:
)


def test_validate_filter_clause_valid():
def test_sanitize_clause_valid():
# regular clauses
assert validate_filter_clause("col = 1") is None
assert validate_filter_clause("1=\t\n1") is None
assert validate_filter_clause("(col = 1)") is None
assert validate_filter_clause("(col1 = 1) AND (col2 = 2)") is None
assert sanitize_clause("col = 1") == "col = 1"
assert sanitize_clause("1=\t\n1") == "1=\t\n1"
assert sanitize_clause("(col = 1)") == "(col = 1)"
assert sanitize_clause("(col1 = 1) AND (col2 = 2)") == "(col1 = 1) AND (col2 = 2)"
assert sanitize_clause("col = 'abc' -- comment") == "col = 'abc' -- comment\n"

# Valid literal values that appear to be invalid
assert validate_filter_clause("col = 'col1 = 1) AND (col2 = 2'") is None
assert validate_filter_clause("col = 'select 1; select 2'") is None
assert validate_filter_clause("col = 'abc -- comment'") is None


def test_validate_filter_clause_closing_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("col1 = 1) AND (col2 = 2)")


def test_validate_filter_clause_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("(col1 = 1) AND (col2 = 2")
# Valid literal values that could be flagged as invalid by a naive query parser
assert (
sanitize_clause("col = 'col1 = 1) AND (col2 = 2'")
== "col = 'col1 = 1) AND (col2 = 2'"
)
assert sanitize_clause("col = 'select 1; select 2'") == "col = 'select 1; select 2'"
assert sanitize_clause("col = 'abc -- comment'") == "col = 'abc -- comment'"


def test_validate_filter_clause_closing_and_unclosed():
def test_sanitize_clause_closing_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("col1 = 1) AND (col2 = 2")
sanitize_clause("col1 = 1) AND (col2 = 2)")


def test_validate_filter_clause_closing_and_unclosed_nested():
def test_sanitize_clause_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("(col1 = 1)) AND ((col2 = 2)")
sanitize_clause("(col1 = 1) AND (col2 = 2")


def test_validate_filter_clause_multiple():
def test_sanitize_clause_closing_and_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("TRUE; SELECT 1")
sanitize_clause("col1 = 1) AND (col2 = 2")


def test_validate_filter_clause_comment():
def test_sanitize_clause_closing_and_unclosed_nested():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("1 = 1 -- comment")
sanitize_clause("(col1 = 1)) AND ((col2 = 2)")


def test_validate_filter_clause_subquery_comment():
def test_sanitize_clause_multiple():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("(1 = 1 -- comment\n)")
sanitize_clause("TRUE; SELECT 1")


def test_sqlparse_issue_652():
Expand Down

0 comments on commit 408573d

Please sign in to comment.