diff --git a/superset/utils/sqllab_execution_context.py b/superset/utils/sqllab_execution_context.py index 2aeb520c30c6..8a3d8de2dcdb 100644 --- a/superset/utils/sqllab_execution_context.py +++ b/superset/utils/sqllab_execution_context.py @@ -19,13 +19,19 @@ import json import logging from dataclasses import dataclass -from typing import Any, cast, Dict, Optional +from typing import Any, cast, Dict, Optional, TYPE_CHECKING from flask import g from superset import app, is_feature_enabled +from superset.models.sql_lab import Query from superset.sql_parse import CtasMethod from superset.utils import core as utils +from superset.utils.dates import now_as_float +from superset.views.utils import get_cta_schema_name + +if TYPE_CHECKING: + from superset.connectors.sqla.models import Database QueryStatus = utils.QueryStatus logger = logging.getLogger(__name__) @@ -42,17 +48,18 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes async_flag: bool limit: int status: str - select_as_cta: bool - ctas_method: CtasMethod - tmp_table_name: str client_id: str client_id_or_short_id: str sql_editor_id: str tab_name: str user_id: Optional[int] expand_data: bool + create_table_as_select: Optional[CreateTableAsSelect] + database: Optional[Database] def __init__(self, query_params: Dict[str, Any]): + self.create_table_as_select = None + self.database = None self._init_from_query_params(query_params) self.user_id = self._get_user_id() self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10]) @@ -65,11 +72,8 @@ def _init_from_query_params(self, query_params: Dict[str, Any]) -> None: self.async_flag = cast(bool, query_params.get("runAsync")) self.limit = self._get_limit_param(query_params) self.status = cast(str, query_params.get("status")) - self.select_as_cta = cast(bool, query_params.get("select_as_cta")) - self.ctas_method = cast( - CtasMethod, query_params.get("ctas_method", CtasMethod.TABLE) - ) - self.tmp_table_name = cast(str, 
query_params.get("tmp_table_name")) + if cast(bool, query_params.get("select_as_cta")): + self.create_table_as_select = CreateTableAsSelect.create_from(query_params) self.client_id = cast(str, query_params.get("client_id")) self.sql_editor_id = cast(str, query_params.get("sql_editor_id")) self.tab_name = cast(str, query_params.get("tab")) @@ -109,3 +113,75 @@ def _get_user_id(self) -> Optional[int]: # pylint: disable=R0201 def is_run_asynchronous(self) -> bool: return self.async_flag + + @property + def select_as_cta(self) -> bool: + return self.create_table_as_select is not None + + def set_database(self, database: Database) -> None: + self._validate_db(database) + self.database = database + if self.select_as_cta: + schema_name = self._get_ctas_target_schema_name(database) + self.create_table_as_select.target_schema_name = schema_name # type: ignore + + def _get_ctas_target_schema_name(self, database: Database) -> Optional[str]: + if database.force_ctas_schema: + return database.force_ctas_schema + return get_cta_schema_name(database, g.user, self.schema, self.sql) + + def _validate_db(self, database: Database) -> None: + # TODO validate db.id is equal to self.database_id + pass + + def create_query(self) -> Query: + # pylint: disable=C0301 + start_time = now_as_float() + if self.select_as_cta: + return Query( + database_id=self.database_id, + sql=self.sql, + schema=self.schema, + select_as_cta=True, + ctas_method=self.create_table_as_select.ctas_method, # type: ignore + start_time=start_time, + tab_name=self.tab_name, + status=self.status, + sql_editor_id=self.sql_editor_id, + tmp_table_name=self.create_table_as_select.target_table_name, # type: ignore + tmp_schema_name=self.create_table_as_select.target_schema_name, # type: ignore + user_id=self.user_id, + client_id=self.client_id_or_short_id, + ) + return Query( + database_id=self.database_id, + sql=self.sql, + schema=self.schema, + select_as_cta=False, + start_time=start_time, + tab_name=self.tab_name, + 
status=self.status, + sql_editor_id=self.sql_editor_id, + user_id=self.user_id, + client_id=self.client_id_or_short_id, + ) + + +class CreateTableAsSelect: # pylint: disable=R0903 + ctas_method: CtasMethod + target_schema_name: Optional[str] + target_table_name: str + + def __init__( + self, ctas_method: CtasMethod, target_schema_name: str, target_table_name: str + ): + self.ctas_method = ctas_method + self.target_schema_name = target_schema_name + self.target_table_name = target_table_name + + @staticmethod + def create_from(query_params: Dict[str, Any]) -> CreateTableAsSelect: + ctas_method = query_params.get("ctas_method", CtasMethod.TABLE) + schema = cast(str, query_params.get("schema")) + tmp_table_name = cast(str, query_params.get("tmp_table_name")) + return CreateTableAsSelect(ctas_method, schema, tmp_table_name) diff --git a/superset/views/core.py b/superset/views/core.py index 5753a78a7481..d3c56026cca2 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -136,7 +136,6 @@ check_explore_cache_perms, check_resource_permissions, check_slice_perms, - get_cta_schema_name, get_dashboard_extra_filters, get_datasource_info, get_form_data, @@ -2567,7 +2566,7 @@ def is_query_handled(cls, query: Optional[Query]) -> bool: QueryStatus.TIMED_OUT, ] - def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals + def sql_json_exec( # pylint: disable=too-many-statements self, execution_context: SqlJsonExecutionContext, query_params: Dict[str, Any], @@ -2580,42 +2579,13 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals query = self._get_existing_query(execution_context, session) if self.is_query_handled(query): - # return the existing query - payload = json.dumps( - {"query": query.to_dict()}, default=utils.json_int_dttm_ser # type: ignore - ) + payload = self._convert_query_to_payload(cast(Query, query)) return json_success(payload) - mydb = self._get_the_query_db(execution_context, session) - - # Set tmp_schema_name 
for CTA - # TODO(bkyryliuk): consider parsing, splitting tmp_schema_name from - # tmp_table_name if user enters - # schema_name.table_name - tmp_schema_name: Optional[str] = execution_context.schema - if execution_context.select_as_cta and mydb.force_ctas_schema: - tmp_schema_name = mydb.force_ctas_schema - elif execution_context.select_as_cta: - tmp_schema_name = get_cta_schema_name( - mydb, g.user, execution_context.schema, execution_context.sql - ) - - # Save current query - query = Query( - database_id=execution_context.database_id, - sql=execution_context.sql, - schema=execution_context.schema, - select_as_cta=execution_context.select_as_cta, - ctas_method=execution_context.ctas_method, - start_time=now_as_float(), - tab_name=execution_context.tab_name, - status=execution_context.status, - sql_editor_id=execution_context.sql_editor_id, - tmp_table_name=execution_context.tmp_table_name, - tmp_schema_name=tmp_schema_name, - user_id=execution_context.user_id, - client_id=execution_context.client_id_or_short_id, + execution_context.set_database( + self._get_the_query_db(execution_context, session) ) + query = execution_context.create_query() try: session.add(query) session.flush() @@ -2684,12 +2654,12 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals }, ) - # Limit is not applied to the CTA queries if SQLLAB_CTAS_NO_LIMIT flag is set - # to True. 
if not (config.get("SQLLAB_CTAS_NO_LIMIT") and execution_context.select_as_cta): # set LIMIT after template processing limits = [ - mydb.db_engine_spec.get_limit_from_sql(rendered_query), + execution_context.database.db_engine_spec.get_limit_from_sql( # type: ignore + rendered_query + ), execution_context.limit, ] if limits[0] is None or limits[0] > limits[1]: # type: ignore diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index e645ef78a232..eb55c7c924f8 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -72,6 +72,7 @@ def get_query_by_id(id: int): @pytest.fixture(autouse=True, scope="module") def setup_sqllab(): + with app.app_context(): yield @@ -216,7 +217,8 @@ def test_run_sync_query_cta_no_data(setup_sqllab): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) @mock.patch( - "superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME + "superset.utils.sqllab_execution_context.get_cta_schema_name", + lambda d, u, s, sql: CTAS_SCHEMA_NAME, ) def test_run_sync_query_cta_config(setup_sqllab, ctas_method): if backend() == "sqlite": @@ -243,7 +245,8 @@ def test_run_sync_query_cta_config(setup_sqllab, ctas_method): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) @mock.patch( - "superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME + "superset.utils.sqllab_execution_context.get_cta_schema_name", + lambda d, u, s, sql: CTAS_SCHEMA_NAME, ) def test_run_async_query_cta_config(setup_sqllab, ctas_method): if backend() == "sqlite": diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index f5b7859ba975..d712daec4c97 100644 --- a/tests/integration_tests/sqllab_tests.py +++ 
b/tests/integration_tests/sqllab_tests.py @@ -63,6 +63,7 @@ QUERY_3 = "SELECT * FROM birth_names LIMIT 10" +@pytest.mark.sqllab class TestSqlLab(SupersetTestCase): """Testings for Sql Lab""" @@ -188,7 +189,7 @@ def test_sql_json_cta_dynamic_db(self, ctas_method): return with mock.patch( - "superset.views.core.get_cta_schema_name", + "superset.utils.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: f"{u.username}_database", ): old_allow_ctas = examples_db.allow_ctas