diff --git a/superset/queries/dao.py b/superset/queries/dao.py index 2f438bdb369f..c7d59343e858 100644 --- a/superset/queries/dao.py +++ b/superset/queries/dao.py @@ -16,6 +16,7 @@ # under the License. import logging from datetime import datetime +from typing import Any, Dict from superset.dao.base import BaseDAO from superset.extensions import db @@ -48,3 +49,10 @@ def update_saved_query_exec_info(query_id: int) -> None: saved_query.rows = query.rows saved_query.last_run = datetime.now() db.session.commit() + + @staticmethod + def save_metadata(query: Query, payload: Dict[str, Any]) -> None: + # pull relevant data from payload and store in extra_json + columns = payload.get("columns", {}) + db.session.add(query) + query.set_extra_json_key("columns", columns) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index cd3684b1d8ea..2eeb2976b412 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -187,7 +187,7 @@ def get_sql_results( # pylint: disable=too-many-arguments return handle_query_error(ex, query, session) -def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements +def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statements sql_statement: str, query: Query, session: Session, diff --git a/superset/sqllab/command.py b/superset/sqllab/command.py index ff50e18eda9e..690a8c4c9ebe 100644 --- a/superset/sqllab/command.py +++ b/superset/sqllab/command.py @@ -34,6 +34,7 @@ QueryIsForbiddenToAccessException, SqlLabException, ) +from superset.sqllab.execution_context_convertor import ExecutionContextConvertor from superset.sqllab.limiting_factor import LimitingFactor if TYPE_CHECKING: @@ -42,6 +43,7 @@ from superset.sqllab.sql_json_executer import SqlJsonExecutor from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext + logger = logging.getLogger(__name__) CommandResult = Dict[str, Any] @@ -94,11 +96,19 @@ def run( # pylint: disable=too-many-statements,useless-suppression status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED else: status = self._run_sql_json_exec_from_scratch() + + self._execution_context_convertor.set_payload( + self._execution_context, status + ) + + # save columns into metadata_json + self._query_dao.save_metadata( + self._execution_context.query, self._execution_context_convertor.payload + ) + return { "status": status, - "payload": self._execution_context_convertor.to_payload( - self._execution_context, status - ), + "payload": self._execution_context_convertor.serialize_payload(), } except (SqlLabException, SupersetErrorsException) as ex: raise ex @@ -209,12 +219,3 @@ def validate(self, query: Query) -> None: class SqlQueryRender: def render(self, execution_context: SqlJsonExecutionContext) -> str: raise NotImplementedError() - - -class ExecutionContextConvertor: - def to_payload( - self, - execution_context: SqlJsonExecutionContext, - execution_status: SqlJsonExecutionStatus, - ) -> str: - raise NotImplementedError() diff --git a/superset/sqllab/execution_context_convertor.py b/superset/sqllab/execution_context_convertor.py index 20eabfa7b3f4..f49fbd9a31db 100644 --- a/superset/sqllab/execution_context_convertor.py +++ b/superset/sqllab/execution_context_convertor.py @@ -16,52 +16,53 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING +import logging +from typing import Any, Dict, TYPE_CHECKING import simplejson as json import superset.utils.core as utils -from superset.sqllab.command import ExecutionContextConvertor from superset.sqllab.command_status import SqlJsonExecutionStatus from superset.sqllab.utils import apply_display_max_row_configuration_if_require +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from superset.models.sql_lab import Query from superset.sqllab.sql_json_executer import SqlResults from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext -class ExecutionContextConvertorImpl(ExecutionContextConvertor): +class ExecutionContextConvertor: _max_row_in_display_configuration: int # pylint: disable=invalid-name + _exc_status: SqlJsonExecutionStatus + payload: Dict[str, Any] def set_max_row_in_display(self, value: int) -> None: self._max_row_in_display_configuration = value # pylint: disable=invalid-name - def to_payload( + def set_payload( self, execution_context: SqlJsonExecutionContext, execution_status: SqlJsonExecutionStatus, - ) -> str: - + ) -> None: + self._exc_status = execution_status if execution_status == SqlJsonExecutionStatus.HAS_RESULTS: - return self._to_payload_results_based( - execution_context.get_execution_result() or {} - ) - return self._to_payload_query_based(execution_context.query) + self.payload = execution_context.get_execution_result() or {} + else: + self.payload = execution_context.query.to_dict() - def _to_payload_results_based(self, execution_result: SqlResults) -> str: - return json.dumps( - apply_display_max_row_configuration_if_require( - execution_result, self._max_row_in_display_configuration - ), - default=utils.pessimistic_json_iso_dttm_ser, - ignore_nan=True, - encoding=None, - ) + def serialize_payload(self) -> str: + if self._exc_status == SqlJsonExecutionStatus.HAS_RESULTS: + return json.dumps( + apply_display_max_row_configuration_if_require( + self.payload, self._max_row_in_display_configuration + ), + default=utils.pessimistic_json_iso_dttm_ser, + ignore_nan=True, + encoding=None, + ) - def _to_payload_query_based( # pylint: disable=no-self-use - self, query: Query - ) -> str: return json.dumps( - {"query": query.to_dict()}, default=utils.json_int_dttm_ser, ignore_nan=True + {"query": self.payload}, default=utils.json_int_dttm_ser, ignore_nan=True ) diff --git a/superset/views/core.py b/superset/views/core.py index fa3437e47b9c..15ff3b1620e9 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -113,7 +113,7 @@ QueryIsForbiddenToAccessException, SqlLabException, ) -from superset.sqllab.execution_context_convertor import ExecutionContextConvertorImpl +from superset.sqllab.execution_context_convertor import ExecutionContextConvertor from superset.sqllab.limiting_factor import LimitingFactor from superset.sqllab.query_render import SqlQueryRenderImpl from superset.sqllab.sql_json_executer import ( @@ -1331,7 +1331,7 @@ def add_slices( # pylint: disable=no-self-use @has_access_api @event_logger.log_this @expose("/testconn", methods=["POST", "GET"]) - def testconn(self) -> FlaskResponse: # pylint: disable=no-self-use + def testconn(self) -> FlaskResponse: """Tests a sqla connection""" logger.warning( "%s.testconn " @@ -2306,7 +2306,7 @@ def stop_query(self) -> FlaskResponse: @event_logger.log_this @expose("/validate_sql_json/", methods=["POST", "GET"]) def validate_sql_json( - # pylint: disable=too-many-locals,no-self-use + # pylint: disable=too-many-locals self, ) -> FlaskResponse: """Validates that arbitrary sql is acceptable for the given database. @@ -2406,7 +2406,7 @@ def _create_sql_json_command( sql_json_executor = Superset._create_sql_json_executor( execution_context, query_dao ) - execution_context_convertor = ExecutionContextConvertorImpl() + execution_context_convertor = ExecutionContextConvertor() execution_context_convertor.set_max_row_in_display( int(config.get("DISPLAY_MAX_ROW")) # type: ignore ) diff --git a/tests/unit_tests/dao/queries_test.py b/tests/unit_tests/dao/queries_test.py new file mode 100644 index 000000000000..8df6d2066aac --- /dev/null +++ b/tests/unit_tests/dao/queries_test.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + + +def test_query_dao_save_metadata(app_context: None, session: Session) -> None: + from superset.models.core import Database + from superset.models.sql_lab import Query + + engine = session.get_bind() + Query.metadata.create_all(engine) # pylint: disable=no-member + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + + query_obj = Query( + client_id="foo", + database=db, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=100, + error_message="none", + results_key="abc", + ) + + session.add(db) + session.add(query_obj) + + from superset.queries.dao import QueryDAO + + query = session.query(Query).one() + QueryDAO.save_metadata(query=query, payload={"columns": []}) + assert query.extra.get("columns", None) == []