From 408573d4d6de1f02d8a46277b49256acb2e00f44 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Sat, 19 Mar 2022 00:08:06 +0200 Subject: [PATCH] feat: add support for comments in adhoc clauses (#19248) * feat: add support for comments in adhoc clauses * sanitize remaining freeform clauses * sanitize adhoc having in frontend * address review comment (cherry picked from commit f341025d80aacf7345e7c20f8463231b9197ea58) --- .../src/query/processFilters.ts | 12 ++++- .../test/query/processFilters.test.ts | 10 ++-- superset/common/query_object.py | 10 ++-- superset/connectors/sqla/models.py | 23 +++++--- superset/sql_parse.py | 21 ++++++-- superset/utils/core.py | 7 ++- superset/viz.py | 10 ++-- .../charts/data/api_tests.py | 22 ++++++++ tests/unit_tests/sql_parse_tests.py | 54 +++++++++---------- 9 files changed, 109 insertions(+), 60 deletions(-) diff --git a/superset-frontend/packages/superset-ui-core/src/query/processFilters.ts b/superset-frontend/packages/superset-ui-core/src/query/processFilters.ts index 8ead77c0fc34b..239f1c49afbe5 100644 --- a/superset-frontend/packages/superset-ui-core/src/query/processFilters.ts +++ b/superset-frontend/packages/superset-ui-core/src/query/processFilters.ts @@ -23,6 +23,14 @@ import { QueryObjectFilterClause } from './types/Query'; import { isSimpleAdhocFilter } from './types/Filter'; import convertFilter from './convertFilter'; +function sanitizeClause(clause: string): string { + let sanitizedClause = clause; + if (clause.includes('--')) { + sanitizedClause = `${clause}\n`; + } + return `(${sanitizedClause})`; +} + /** Logic formerly in viz.py's process_query_filters */ export default function processFilters( formData: Partial, @@ -60,9 +68,9 @@ export default function processFilters( }); // some filter-related fields need to go in `extras` - extras.having = freeformHaving.map(exp => `(${exp})`).join(' AND '); + extras.having = freeformHaving.map(sanitizeClause).join(' AND '); extras.having_druid = simpleHaving; - extras.where = freeformWhere.map(exp => `(${exp})`).join(' AND '); + extras.where = freeformWhere.map(sanitizeClause).join(' AND '); return { filters: simpleWhere, diff --git a/superset-frontend/packages/superset-ui-core/test/query/processFilters.test.ts b/superset-frontend/packages/superset-ui-core/test/query/processFilters.test.ts index 267b416493e35..151c0363f16f0 100644 --- a/superset-frontend/packages/superset-ui-core/test/query/processFilters.test.ts +++ b/superset-frontend/packages/superset-ui-core/test/query/processFilters.test.ts @@ -132,12 +132,12 @@ describe('processFilters', () => { { expressionType: 'SQL', clause: 'WHERE', - sqlExpression: 'tea = "jasmine"', + sqlExpression: "tea = 'jasmine'", }, { expressionType: 'SQL', clause: 'WHERE', - sqlExpression: 'cup = "large"', + sqlExpression: "cup = 'large' -- comment", }, { expressionType: 'SQL', @@ -147,13 +147,13 @@ describe('processFilters', () => { { expressionType: 'SQL', clause: 'HAVING', - sqlExpression: 'waitTime <= 180', + sqlExpression: 'waitTime <= 180 -- comment', }, ], }), ).toEqual({ extras: { - having: '(ice = 25 OR ice = 50) AND (waitTime <= 180)', + having: '(ice = 25 OR ice = 50) AND (waitTime <= 180 -- comment\n)', having_druid: [ { col: 'sweetness', @@ -166,7 +166,7 @@ describe('processFilters', () => { val: '50', }, ], - where: '(tea = "jasmine") AND (cup = "large")', + where: "(tea = 'jasmine') AND (cup = 'large' -- comment\n)", }, filters: [ { diff --git a/superset/common/query_object.py b/superset/common/query_object.py index fd988a36fac05..139dc27c580ef 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -30,7 +30,7 @@ QueryClauseValidationException, QueryObjectValidationError, ) -from superset.sql_parse import validate_filter_clause +from superset.sql_parse import sanitize_clause from superset.superset_typing import Column, Metric, OrderBy from superset.utils import pandas_postprocessing from superset.utils.core import ( @@ -272,7 +272,7 @@ def validate( try: self._validate_there_are_no_missing_series() self._validate_no_have_duplicate_labels() - self._validate_filters() + self._sanitize_filters() return None except QueryObjectValidationError as ex: if raise_exceptions: @@ -291,12 +291,14 @@ def _validate_no_have_duplicate_labels(self) -> None: ) ) - def _validate_filters(self) -> None: + def _sanitize_filters(self) -> None: for param in ("where", "having"): clause = self.extras.get(param) if clause: try: - validate_filter_clause(clause) + sanitized_clause = sanitize_clause(clause) + if sanitized_clause != clause: + self.extras[param] = sanitized_clause except QueryClauseValidationException as ex: raise QueryObjectValidationError(ex.message) from ex diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 8702d5c1a7506..4382dda48d4fa 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -82,7 +82,10 @@ ) from superset.datasets.models import Dataset as NewDataset from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression -from superset.exceptions import QueryObjectValidationError +from superset.exceptions import ( + QueryClauseValidationException, + QueryObjectValidationError, +) from superset.jinja_context import ( BaseTemplateProcessor, ExtraCache, @@ -96,7 +99,7 @@ clone_model, QueryResult, ) -from superset.sql_parse import ParsedQuery +from superset.sql_parse import ParsedQuery, sanitize_clause from superset.superset_typing import ( AdhocColumn, AdhocMetric, @@ -918,6 +921,10 @@ def adhoc_metric_to_sqla( tp = self.get_template_processor() expression = tp.process_template(cast(str, metric["sqlExpression"])) validate_adhoc_subquery(expression) + try: + expression = sanitize_clause(expression) + except QueryClauseValidationException as ex: + raise QueryObjectValidationError(ex.message) from ex sqla_metric = literal_column(expression) else: raise QueryObjectValidationError("Adhoc metric expressionType is invalid") @@ -943,6 +950,10 @@ def adhoc_column_to_sqla( expression = template_processor.process_template(expression) if expression: validate_adhoc_subquery(expression) + try: + expression = sanitize_clause(expression) + except QueryClauseValidationException as ex: + raise QueryObjectValidationError(ex.message) from ex sqla_metric = literal_column(expression) return self.make_sqla_column_compatible(sqla_metric, label) @@ -1388,7 +1399,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma where = extras.get("where") if where: try: - where = template_processor.process_template(where) + where = template_processor.process_template(f"({where})") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1396,11 +1407,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex - where_clause_and += [self.text(f"({where})")] + where_clause_and += [self.text(where)] having = extras.get("having") if having: try: - having = template_processor.process_template(having) + having = template_processor.process_template(f"({having})") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1408,7 +1419,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex - having_clause_and += [self.text(f"({having})")] + having_clause_and += [self.text(having)] if apply_fetch_values_predicate and self.fetch_values_predicate: qry = qry.where(self.get_fetch_values_predicate()) if granularity: diff --git a/superset/sql_parse.py b/superset/sql_parse.py index f5523bab71e8d..95361b39a6a27 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -32,6 +32,7 @@ Where, ) from sqlparse.tokens import ( + Comment, CTE, DDL, DML, @@ -441,25 +442,35 @@ def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str: return str_res -def validate_filter_clause(clause: str) -> None: - if sqlparse.format(clause, strip_comments=True) != sqlparse.format(clause): - raise QueryClauseValidationException("Filter clause contains comment") - +def sanitize_clause(clause: str) -> str: + # clause = sqlparse.format(clause, strip_comments=True) statements = sqlparse.parse(clause) if len(statements) != 1: - raise QueryClauseValidationException("Filter clause contains multiple queries") + raise QueryClauseValidationException("Clause contains multiple statements") open_parens = 0 + previous_token = None for token in statements[0]: + if token.value == "/" and previous_token and previous_token.value == "*": + raise QueryClauseValidationException("Closing unopened multiline comment") + if token.value == "*" and previous_token and previous_token.value == "/": + raise QueryClauseValidationException("Unclosed multiline comment") if token.value in (")", "("): open_parens += 1 if token.value == "(" else -1 if open_parens < 0: raise QueryClauseValidationException( "Closing unclosed parenthesis in filter clause" ) + previous_token = token if open_parens > 0: raise QueryClauseValidationException("Unclosed parenthesis in filter clause") + if previous_token and previous_token.ttype in Comment: + if previous_token.value[-1] != "\n": + clause = f"{clause}\n" + + return clause + class InsertRLSState(str, Enum): """ diff --git a/superset/utils/core.py b/superset/utils/core.py index a18df6dea8a9a..b348137802ee7 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -98,6 +98,7 @@ SupersetException, SupersetTimeoutException, ) +from superset.sql_parse import sanitize_clause from superset.superset_typing import ( AdhocColumn, AdhocMetric, @@ -1382,10 +1383,12 @@ def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name } ) elif expression_type == "SQL": + sql_expression = adhoc_filter.get("sqlExpression") + sql_expression = sanitize_clause(sql_expression) if clause == "WHERE": - sql_where_filters.append(adhoc_filter.get("sqlExpression")) + sql_where_filters.append(sql_expression) elif clause == "HAVING": - sql_having_filters.append(adhoc_filter.get("sqlExpression")) + sql_having_filters.append(sql_expression) form_data["where"] = " AND ".join( ["({})".format(sql) for sql in sql_where_filters] ) diff --git a/superset/viz.py b/superset/viz.py index 02ad2855e0f6c..21afcf038c94b 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -62,14 +62,13 @@ from superset.exceptions import ( CacheLoadError, NullValueException, - QueryClauseValidationException, QueryObjectValidationError, SpatialException, SupersetSecurityException, ) from superset.extensions import cache_manager, security_manager from superset.models.helpers import QueryResult -from superset.sql_parse import validate_filter_clause +from superset.sql_parse import sanitize_clause from superset.superset_typing import ( Column, Metric, @@ -391,10 +390,9 @@ def query_obj(self) -> QueryObjectDict: # pylint: disable=too-many-locals for param in ("where", "having"): clause = self.form_data.get(param) if clause: - try: - validate_filter_clause(clause) - except QueryClauseValidationException as ex: - raise QueryObjectValidationError(ex.message) from ex + sanitized_clause = sanitize_clause(clause) + if sanitized_clause != clause: + self.form_data[param] = sanitized_clause # extras are used to query elements specific to a datasource type # for instance the extra where clause that applies only to Tables diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index b09123f7fa68e..ca671cad91803 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -465,6 +465,28 @@ def test_with_invalid_where_parameter_closing_unclosed__400(self): assert rv.status_code == 400 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_where_parameter_including_comment___200(self): + self.query_context_payload["queries"][0]["filters"] = [] + self.query_context_payload["queries"][0]["extras"]["where"] = "1 = 1 -- abc" + + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + + assert rv.status_code == 200 + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_orderby_parameter_with_second_query__400(self): + self.query_context_payload["queries"][0]["filters"] = [] + self.query_context_payload["queries"][0]["orderby"] = [ + [ + {"expressionType": "SQL", "sqlExpression": "sum__num; select 1, 1",}, + True, + ], + ] + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + + assert rv.status_code == 400 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_invalid_having_parameter_closing_and_comment__400(self): self.query_context_payload["queries"][0]["filters"] = [] diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 886eb368e4aa4..75f099e52b6e1 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -30,9 +30,9 @@ insert_rls, matches_table_name, ParsedQuery, + sanitize_clause, strip_comments_from_sql, Table, - validate_filter_clause, ) @@ -1142,52 +1142,46 @@ def test_strip_comments_from_sql() -> None: ) -def test_validate_filter_clause_valid(): +def test_sanitize_clause_valid(): # regular clauses - assert validate_filter_clause("col = 1") is None - assert validate_filter_clause("1=\t\n1") is None - assert validate_filter_clause("(col = 1)") is None - assert validate_filter_clause("(col1 = 1) AND (col2 = 2)") is None + assert sanitize_clause("col = 1") == "col = 1" + assert sanitize_clause("1=\t\n1") == "1=\t\n1" + assert sanitize_clause("(col = 1)") == "(col = 1)" + assert sanitize_clause("(col1 = 1) AND (col2 = 2)") == "(col1 = 1) AND (col2 = 2)" + assert sanitize_clause("col = 'abc' -- comment") == "col = 'abc' -- comment\n" - # Valid literal values that appear to be invalid - assert validate_filter_clause("col = 'col1 = 1) AND (col2 = 2'") is None - assert validate_filter_clause("col = 'select 1; select 2'") is None - assert validate_filter_clause("col = 'abc -- comment'") is None - - -def test_validate_filter_clause_closing_unclosed(): - with pytest.raises(QueryClauseValidationException): - validate_filter_clause("col1 = 1) AND (col2 = 2)") - - -def test_validate_filter_clause_unclosed(): - with pytest.raises(QueryClauseValidationException): - validate_filter_clause("(col1 = 1) AND (col2 = 2") + # Valid literal values that at could be flagged as invalid by a naive query parser + assert ( + sanitize_clause("col = 'col1 = 1) AND (col2 = 2'") + == "col = 'col1 = 1) AND (col2 = 2'" + ) + assert sanitize_clause("col = 'select 1; select 2'") == "col = 'select 1; select 2'" + assert sanitize_clause("col = 'abc -- comment'") == "col = 'abc -- comment'" -def test_validate_filter_clause_closing_and_unclosed(): +def test_sanitize_clause_closing_unclosed(): with pytest.raises(QueryClauseValidationException): - validate_filter_clause("col1 = 1) AND (col2 = 2") + sanitize_clause("col1 = 1) AND (col2 = 2)") -def test_validate_filter_clause_closing_and_unclosed_nested(): +def test_sanitize_clause_unclosed(): with pytest.raises(QueryClauseValidationException): - validate_filter_clause("(col1 = 1)) AND ((col2 = 2)") + sanitize_clause("(col1 = 1) AND (col2 = 2") -def test_validate_filter_clause_multiple(): +def test_sanitize_clause_closing_and_unclosed(): with pytest.raises(QueryClauseValidationException): - validate_filter_clause("TRUE; SELECT 1") + sanitize_clause("col1 = 1) AND (col2 = 2") -def test_validate_filter_clause_comment(): +def test_sanitize_clause_closing_and_unclosed_nested(): with pytest.raises(QueryClauseValidationException): - validate_filter_clause("1 = 1 -- comment") + sanitize_clause("(col1 = 1)) AND ((col2 = 2)") -def test_validate_filter_clause_subquery_comment(): +def test_sanitize_clause_multiple(): with pytest.raises(QueryClauseValidationException): - validate_filter_clause("(1 = 1 -- comment\n)") + sanitize_clause("TRUE; SELECT 1") def test_sqlparse_issue_652():