diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index e4d4db17e9cb..b950d0e37737 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -369,24 +369,19 @@ export function validateQuery(query) { dispatch(startQueryValidation(query)); const postPayload = { - client_id: query.id, - database_id: query.dbId, - json: true, schema: query.schema, sql: query.sql, - sql_editor_id: query.sqlEditorId, - templateParams: query.templateParams, - validate_only: true, + template_params: query.templateParams, }; return SupersetClient.post({ - endpoint: `/superset/validate_sql_json/${window.location.search}`, - postPayload, - stringify: false, + endpoint: `/api/v1/database/${query.dbId}/validate_sql`, + body: JSON.stringify(postPayload), + headers: { 'Content-Type': 'application/json' }, }) - .then(({ json }) => dispatch(queryValidationReturned(query, json))) + .then(({ json }) => dispatch(queryValidationReturned(query, json.result))) .catch(response => - getClientErrorObject(response).then(error => { + getClientErrorObject(response.result).then(error => { let message = error.error || error.statusText || t('Unknown error'); if (message.includes('CSRF token')) { message = t(COMMON_ERR_MESSAGES.SESSION_TIMED_OUT); diff --git a/superset/constants.py b/superset/constants.py index 821562256a15..98ce7c5d112f 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -127,6 +127,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods "get_datasets": "read", "function_names": "read", "available": "read", + "validate_sql": "read", "get_data": "read", } diff --git a/superset/databases/api.py b/superset/databases/api.py index 3833edab558f..1afa71c6f056 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -49,6 +49,7 @@ from superset.databases.commands.test_connection import TestConnectionDatabaseCommand from superset.databases.commands.update import UpdateDatabaseCommand from superset.databases.commands.validate import ValidateDatabaseParametersCommand +from superset.databases.commands.validate_sql import ValidateSQLCommand from superset.databases.dao import DatabaseDAO from superset.databases.decorators import check_datasource_access from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter @@ -65,6 +66,8 @@ SelectStarResponseSchema, TableExtraMetadataResponseSchema, TableMetadataResponseSchema, + ValidateSQLRequest, + ValidateSQLResponse, ) from superset.databases.utils import get_table_metadata from superset.db_engine_specs import get_available_engine_specs @@ -98,6 +101,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "function_names", "available", "validate_parameters", + "validate_sql", } resource_name = "database" class_permission_name = "Database" @@ -193,6 +197,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "database_schemas_query_schema": database_schemas_query_schema, "get_export_ids_schema": get_export_ids_schema, } + openapi_spec_tag = "Database" openapi_spec_component_schemas = ( DatabaseFunctionNamesResponse, @@ -203,6 +208,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi): TableMetadataResponseSchema, SelectStarResponseSchema, SchemasResponseSchema, + ValidateSQLRequest, + ValidateSQLResponse, ) @expose("/", methods=["POST"]) @@ -771,6 +778,66 @@ def related_objects(self, pk: int) -> Response: }, ) + @expose("//validate_sql", methods=["POST"]) + @protect() + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.validate_sql", + log_to_statsd=False, + ) + def validate_sql(self, pk: int) -> FlaskResponse: + """ + --- + post: + summary: >- + Validates that arbitrary sql is acceptable for the given database + description: >- + Validates arbitrary SQL. + parameters: + - in: path + schema: + type: integer + name: pk + requestBody: + description: Validate SQL request + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/ValidateSQLRequest' + responses: + 200: + description: Validation result + content: + application/json: + schema: + type: object + properties: + result: + description: >- + A List of SQL errors found on the statement + type: array + items: + $ref: '#/components/schemas/ValidateSQLResponse' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + try: + sql_request = ValidateSQLRequest().load(request.json) + except ValidationError as error: + return self.response_400(message=error.messages) + try: + validator_errors = ValidateSQLCommand(pk, sql_request).run() + return self.response(200, result=validator_errors) + except DatabaseNotFoundError: + return self.response_404() + @expose("/export/", methods=["GET"]) @protect() @safe diff --git a/superset/databases/commands/exceptions.py b/superset/databases/commands/exceptions.py index bde76c021c88..a49abd3449d0 100644 --- a/superset/databases/commands/exceptions.py +++ b/superset/databases/commands/exceptions.py @@ -137,6 +137,31 @@ class DatabaseTestConnectionUnexpectedError(SupersetErrorsException): message = _("Unexpected error occurred, please check your logs for details") +class NoValidatorConfigFoundError(SupersetErrorException): + status = 422 + message = _("no SQL validator is configured") + + +class NoValidatorFoundError(SupersetErrorException): + status = 422 + message = _("No validator found (configured for the engine)") + + +class ValidatorSQLError(SupersetErrorException): + status = 422 + message = _("Was unable to check your query") + + +class ValidatorSQLUnexpectedError(CommandException): + status = 422 + message = _("An unexpected error occurred") + + +class ValidatorSQL400Error(SupersetErrorException): + status = 400 + message = _("Was unable to check your query") + + class DatabaseImportError(ImportFailedError): message = _("Import database failed for an unknown reason") diff --git a/superset/databases/commands/validate_sql.py b/superset/databases/commands/validate_sql.py new file mode 100644 index 000000000000..346d684a0d2c --- /dev/null +++ b/superset/databases/commands/validate_sql.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +import re +from typing import Any, Dict, List, Optional, Type + +from flask import current_app +from flask_babel import gettext as __ + +from superset.commands.base import BaseCommand +from superset.databases.commands.exceptions import ( + DatabaseNotFoundError, + NoValidatorConfigFoundError, + NoValidatorFoundError, + ValidatorSQL400Error, + ValidatorSQLError, + ValidatorSQLUnexpectedError, +) +from superset.databases.dao import DatabaseDAO +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.models.core import Database +from superset.sql_validators import get_validator_by_name +from superset.sql_validators.base import BaseSQLValidator +from superset.utils import core as utils + +logger = logging.getLogger(__name__) + + +class ValidateSQLCommand(BaseCommand): + def __init__(self, model_id: int, data: Dict[str, Any]): + self._properties = data.copy() + self._model_id = model_id + self._model: Optional[Database] = None + self._validator: Optional[Type[BaseSQLValidator]] = None + + def run(self) -> List[Dict[str, Any]]: + """ + Validates a SQL statement + + :return: A List of SQLValidationAnnotation + :raises: DatabaseNotFoundError, NoValidatorConfigFoundError + NoValidatorFoundError, ValidatorSQLUnexpectedError, ValidatorSQLError + ValidatorSQL400Error + """ + self.validate() + if not self._validator or not self._model: + raise ValidatorSQLUnexpectedError() + sql = self._properties["sql"] + schema = self._properties.get("schema") + try: + timeout = current_app.config["SQLLAB_VALIDATION_TIMEOUT"] + timeout_msg = f"The query exceeded the {timeout} seconds timeout." + with utils.timeout(seconds=timeout, error_message=timeout_msg): + errors = self._validator.validate(sql, schema, self._model) + return [err.to_dict() for err in errors] + except Exception as ex: + logger.exception(ex) + superset_error = SupersetError( + message=__( + "%(validator)s was unable to check your query.\n" + "Please recheck your query.\n" + "Exception: %(ex)s", + validator=self._validator.name, + ex=ex, + ), + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + + # Return as a 400 if the database error message says we got a 4xx error + if re.search(r"([\W]|^)4\d{2}([\W]|$)", str(ex)): + raise ValidatorSQL400Error(superset_error) from ex + raise ValidatorSQLError(superset_error) from ex + + def validate(self) -> None: + # Validate/populate model exists + self._model = DatabaseDAO.find_by_id(self._model_id) + if not self._model: + raise DatabaseNotFoundError() + + spec = self._model.db_engine_spec + validators_by_engine = current_app.config["SQL_VALIDATORS_BY_ENGINE"] + if not validators_by_engine or spec.engine not in validators_by_engine: + raise NoValidatorConfigFoundError( + SupersetError( + message=__( + "no SQL validator is configured for {}".format(spec.engine) + ), + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ), + ) + validator_name = validators_by_engine[spec.engine] + self._validator = get_validator_by_name(validator_name) + if not self._validator: + raise NoValidatorFoundError( + SupersetError( + message=__( + "No validator named {} found " + "(configured for the {} engine)".format( + validator_name, spec.engine + ) + ), + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ), + ) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index e28b70401705..9378b6d1b8bd 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -541,6 +541,19 @@ class SchemasResponseSchema(Schema): result = fields.List(fields.String(description="A database schema name")) +class ValidateSQLRequest(Schema): + sql = fields.String(required=True, description="SQL statement to validate") + schema = fields.String(required=False, allow_none=True) + template_params = fields.Dict(required=False, allow_none=True) + + +class ValidateSQLResponse(Schema): + line_number = fields.Integer() + start_column = fields.Integer() + end_column = fields.Integer() + message = fields.String() + + class DatabaseRelatedChart(Schema): id = fields.Integer() slice_name = fields.String() diff --git a/superset/views/core.py b/superset/views/core.py index 71338c997a0c..de3d4cb2ea9d 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2311,6 +2311,12 @@ def validate_sql_json( """Validates that arbitrary sql is acceptable for the given database. Returns a list of error/warning annotations as json. """ + logger.warning( + "%s.validate_sql_json " + "This API endpoint is deprecated and will be removed in version 3.0.0", + self.__class__.__name__, + ) + sql = request.form["sql"] database_id = request.form["database_id"] schema = request.form.get("schema") or None diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index a4ddbfd7113c..de754bd60a3a 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -21,6 +21,7 @@ from collections import defaultdict from io import BytesIO from unittest import mock +from unittest.mock import patch, MagicMock from zipfile import is_zipfile, ZipFile from operator import itemgetter @@ -71,6 +72,19 @@ from tests.integration_tests.test_app import app +SQL_VALIDATORS_BY_ENGINE = { + "presto": "PrestoDBSQLValidator", + "postgresql": "PostgreSQLValidator", +} + +PRESTO_SQL_VALIDATORS_BY_ENGINE = { + "presto": "PrestoDBSQLValidator", + "sqlite": "PrestoDBSQLValidator", + "postgresql": "PrestoDBSQLValidator", + "mysql": "PrestoDBSQLValidator", +} + + class TestDatabaseApi(SupersetTestCase): def insert_database( self, @@ -2150,7 +2164,8 @@ def test_validate_parameters_invalid_payload_schema(self): "issue_codes": [ { "code": 1020, - "message": "Issue 1020 - The submitted payload has the incorrect schema.", + "message": "Issue 1020 - The submitted payload" + " has the incorrect schema.", } ], }, @@ -2164,7 +2179,8 @@ def test_validate_parameters_invalid_payload_schema(self): "issue_codes": [ { "code": 1020, - "message": "Issue 1020 - The submitted payload has the incorrect schema.", + "message": "Issue 1020 - The submitted payload " + "has the incorrect schema.", } ], }, @@ -2197,7 +2213,8 @@ def test_validate_parameters_missing_fields(self): assert response == { "errors": [ { - "message": "One or more parameters are missing: database, host, username", + "message": "One or more parameters are missing: database, host," + " username", "error_type": "CONNECTION_MISSING_PARAMETERS_ERROR", "level": "warning", "extra": { @@ -2205,7 +2222,8 @@ def test_validate_parameters_missing_fields(self): "issue_codes": [ { "code": 1018, - "message": "Issue 1018 - One or more parameters needed to configure a database are missing.", + "message": "Issue 1018 - One or more parameters " + "needed to configure a database are missing.", } ], }, @@ -2284,7 +2302,8 @@ def test_validate_parameters_invalid_port(self): }, }, { - "message": "The port must be an integer between 0 and 65535 (inclusive).", + "message": "The port must be an integer between " + "0 and 65535 (inclusive).", "error_type": "CONNECTION_INVALID_PORT_ERROR", "level": "error", "extra": { @@ -2336,7 +2355,8 @@ def test_validate_parameters_invalid_host(self, is_hostname_valid): "issue_codes": [ { "code": 1018, - "message": "Issue 1018 - One or more parameters needed to configure a database are missing.", + "message": "Issue 1018 - One or more parameters" + " needed to configure a database are missing.", } ], }, @@ -2350,7 +2370,8 @@ def test_validate_parameters_invalid_host(self, is_hostname_valid): "issue_codes": [ { "code": 1007, - "message": "Issue 1007 - The hostname provided can't be resolved.", + "message": "Issue 1007 - The hostname " + "provided can't be resolved.", } ], }, @@ -2425,3 +2446,190 @@ def test_get_related_objects(self): assert "charts" in rv.json assert "dashboards" in rv.json assert "sqllab_tab_states" in rv.json + + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + SQL_VALIDATORS_BY_ENGINE, + clear=True, + ) + def test_validate_sql(self): + """ + Database API: validate SQL success + """ + request_payload = { + "sql": "SELECT * from birth_names", + "schema": None, + "template_params": None, + } + + example_db = get_example_database() + if example_db.backend not in ("presto", "postgresql"): + pytest.skip("Only presto and PG are implemented") + + self.login(username="admin") + uri = f"api/v1/database/{example_db.id}/validate_sql" + rv = self.client.post(uri, json=request_payload) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + self.assertEqual(response["result"], []) + + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + SQL_VALIDATORS_BY_ENGINE, + clear=True, + ) + def test_validate_sql_errors(self): + """ + Database API: validate SQL with errors + """ + request_payload = { + "sql": "SELECT col1 froma table1", + "schema": None, + "template_params": None, + } + + example_db = get_example_database() + if example_db.backend not in ("presto", "postgresql"): + pytest.skip("Only presto and PG are implemented") + + self.login(username="admin") + uri = f"api/v1/database/{example_db.id}/validate_sql" + rv = self.client.post(uri, json=request_payload) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + self.assertEqual( + response["result"], + [ + { + "end_column": None, + "line_number": 1, + "message": 'ERROR: syntax error at or near "table1"', + "start_column": None, + } + ], + ) + + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + SQL_VALIDATORS_BY_ENGINE, + clear=True, + ) + def test_validate_sql_not_found(self): + """ + Database API: validate SQL database not found + """ + request_payload = { + "sql": "SELECT * from birth_names", + "schema": None, + "template_params": None, + } + self.login(username="admin") + uri = ( + f"api/v1/database/{self.get_nonexistent_numeric_id(Database)}/validate_sql" + ) + rv = self.client.post(uri, json=request_payload) + self.assertEqual(rv.status_code, 404) + + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + SQL_VALIDATORS_BY_ENGINE, + clear=True, + ) + def test_validate_sql_validation_fails(self): + """ + Database API: validate SQL database payload validation fails + """ + request_payload = { + "sql": None, + "schema": None, + "template_params": None, + } + self.login(username="admin") + uri = ( + f"api/v1/database/{self.get_nonexistent_numeric_id(Database)}/validate_sql" + ) + rv = self.client.post(uri, json=request_payload) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 400) + self.assertEqual(response, {"message": {"sql": ["Field may not be null."]}}) + + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + {}, + clear=True, + ) + def test_validate_sql_endpoint_noconfig(self): + """Assert that validate_sql_json errors out when no validators are + configured for any db""" + request_payload = { + "sql": "SELECT col1 from table1", + "schema": None, + "template_params": None, + } + + self.login("admin") + + example_db = get_example_database() + + uri = f"api/v1/database/{example_db.id}/validate_sql" + rv = self.client.post(uri, json=request_payload) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 422) + self.assertEqual( + response, + { + "errors": [ + { + "message": f"no SQL validator is configured for " + f"{example_db.backend}", + "error_type": "GENERIC_DB_ENGINE_ERROR", + "level": "error", + "extra": { + "issue_codes": [ + { + "code": 1002, + "message": "Issue 1002 - The database returned an " + "unexpected error.", + } + ] + }, + } + ] + }, + ) + + @patch("superset.databases.commands.validate_sql.get_validator_by_name") + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + PRESTO_SQL_VALIDATORS_BY_ENGINE, + clear=True, + ) + def test_validate_sql_endpoint_failure(self, get_validator_by_name): + """Assert that validate_sql_json errors out when the selected validator + raises an unexpected exception""" + + request_payload = { + "sql": "SELECT * FROM birth_names", + "schema": None, + "template_params": None, + } + + self.login("admin") + + validator = MagicMock() + get_validator_by_name.return_value = validator + validator.validate.side_effect = Exception("Kaboom!") + + self.login("admin") + + example_db = get_example_database() + + uri = f"api/v1/database/{example_db.id}/validate_sql" + rv = self.client.post(uri, json=request_payload) + response = json.loads(rv.data.decode("utf-8")) + + # TODO(bkyryliuk): properly handle hive error + if get_example_database().backend == "hive": + return + self.assertEqual(rv.status_code, 422) + self.assertIn("Kaboom!", response["errors"][0]["message"]) diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py index b1e661cc2c5b..57f31ba4b750 100644 --- a/tests/integration_tests/sql_validator_tests.py +++ b/tests/integration_tests/sql_validator_tests.py @@ -48,13 +48,16 @@ class TestSqlValidatorEndpoint(SupersetTestCase): def tearDown(self): self.logout() + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + {}, + clear=True, + ) def test_validate_sql_endpoint_noconfig(self): """Assert that validate_sql_json errors out when no validators are configured for any db""" self.login("admin") - app.config["SQL_VALIDATORS_BY_ENGINE"] = {} - resp = self.validate_sql( "SELECT * FROM birth_names", client_id="1", raise_on_error=False ) @@ -231,6 +234,11 @@ def test_validator_query_error(self, flask_g): self.assertEqual(1, len(errors)) + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + {}, + clear=True, + ) def test_validate_sql_endpoint(self): self.login("admin") # NB this is effectively an integration test -- when there's a default