Skip to content

Commit

Permalink
refactor: sql_json view endpoint: encapsulate ctas parameters (#16548)
Browse files Browse the repository at this point in the history
* refactor sql_json view endpoint: encapsulate ctas parameters

* fix failed tests

* fix failed tests and ci issues
  • Loading branch information
ofekisr committed Sep 5, 2021
1 parent 359383b commit be77ad2
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 50 deletions.
94 changes: 85 additions & 9 deletions superset/utils/sqllab_execution_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@
import json
import logging
from dataclasses import dataclass
from typing import Any, cast, Dict, Optional
from typing import Any, cast, Dict, Optional, TYPE_CHECKING

from flask import g

from superset import app, is_feature_enabled
from superset.models.sql_lab import Query
from superset.sql_parse import CtasMethod
from superset.utils import core as utils
from superset.utils.dates import now_as_float
from superset.views.utils import get_cta_schema_name

if TYPE_CHECKING:
from superset.connectors.sqla.models import Database

QueryStatus = utils.QueryStatus
logger = logging.getLogger(__name__)
Expand All @@ -42,17 +48,18 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
async_flag: bool
limit: int
status: str
select_as_cta: bool
ctas_method: CtasMethod
tmp_table_name: str
client_id: str
client_id_or_short_id: str
sql_editor_id: str
tab_name: str
user_id: Optional[int]
expand_data: bool
create_table_as_select: Optional[CreateTableAsSelect]
database: Optional[Database]

def __init__(self, query_params: Dict[str, Any]):
self.create_table_as_select = None
self.database = None
self._init_from_query_params(query_params)
self.user_id = self._get_user_id()
self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10])
Expand All @@ -65,11 +72,8 @@ def _init_from_query_params(self, query_params: Dict[str, Any]) -> None:
self.async_flag = cast(bool, query_params.get("runAsync"))
self.limit = self._get_limit_param(query_params)
self.status = cast(str, query_params.get("status"))
self.select_as_cta = cast(bool, query_params.get("select_as_cta"))
self.ctas_method = cast(
CtasMethod, query_params.get("ctas_method", CtasMethod.TABLE)
)
self.tmp_table_name = cast(str, query_params.get("tmp_table_name"))
if cast(bool, query_params.get("select_as_cta")):
self.create_table_as_select = CreateTableAsSelect.create_from(query_params)
self.client_id = cast(str, query_params.get("client_id"))
self.sql_editor_id = cast(str, query_params.get("sql_editor_id"))
self.tab_name = cast(str, query_params.get("tab"))
Expand Down Expand Up @@ -109,3 +113,75 @@ def _get_user_id(self) -> Optional[int]: # pylint: disable=R0201

def is_run_asynchronous(self) -> bool:
    """Return the ``async_flag`` value stored on this execution context."""
    run_async = self.async_flag
    return run_async

@property
def select_as_cta(self) -> bool:
    """Tell whether CTAS parameters are attached to this context.

    True exactly when ``create_table_as_select`` holds an object, i.e. it is
    not ``None``.
    """
    ctas_params = self.create_table_as_select
    return ctas_params is not None

def set_database(self, database: Database) -> None:
    """Attach *database* to this context and resolve the CTAS target schema.

    The database is first passed to ``_validate_db`` and then stored on
    ``self.database``. When the context carries CTAS parameters
    (``select_as_cta`` is true), their ``target_schema_name`` is overwritten
    with the value returned by ``_get_ctas_target_schema_name``.
    """
    self._validate_db(database)
    self.database = database
    if not self.select_as_cta:
        # Plain query: there are no CTAS parameters to update.
        return
    resolved_schema = self._get_ctas_target_schema_name(database)
    self.create_table_as_select.target_schema_name = resolved_schema  # type: ignore

def _get_ctas_target_schema_name(self, database: Database) -> Optional[str]:
if database.force_ctas_schema:
return database.force_ctas_schema
return get_cta_schema_name(database, g.user, self.schema, self.sql)

def _validate_db(self, database: Database) -> None:
# TODO validate db.id is equal to self.database_id
pass

def create_query(self) -> Query:
    """Build the ``Query`` model instance that records this execution request.

    Every query receives the context's database id, SQL, schema, tab name,
    status, SQL editor id, user id and client id, plus a start time taken
    from ``now_as_float()``. When CTAS parameters are attached, the CTAS
    method and the target table/schema names are added and ``select_as_cta``
    is ``True``; otherwise ``select_as_cta`` is ``False`` and those fields
    are left out.
    """
    start_time = now_as_float()
    # Keyword arguments shared by the CTAS and the plain variants.
    shared_fields = {
        "database_id": self.database_id,
        "sql": self.sql,
        "schema": self.schema,
        "start_time": start_time,
        "tab_name": self.tab_name,
        "status": self.status,
        "sql_editor_id": self.sql_editor_id,
        "user_id": self.user_id,
        "client_id": self.client_id_or_short_id,
    }
    if not self.select_as_cta:
        return Query(select_as_cta=False, **shared_fields)

    ctas = self.create_table_as_select
    return Query(
        select_as_cta=True,
        ctas_method=ctas.ctas_method,  # type: ignore
        tmp_table_name=ctas.target_table_name,  # type: ignore
        tmp_schema_name=ctas.target_schema_name,  # type: ignore
        **shared_fields,
    )


class CreateTableAsSelect:  # pylint: disable=R0903
    """Holds the parameters of a "create table as select" (CTAS) request.

    Stores the CTAS method together with the schema and table name that the
    statement should target.
    """

    ctas_method: CtasMethod
    target_schema_name: Optional[str]
    target_table_name: str

    def __init__(
        self, ctas_method: CtasMethod, target_schema_name: str, target_table_name: str
    ):
        """Store the three CTAS parameters unchanged on the instance."""
        self.target_table_name = target_table_name
        self.target_schema_name = target_schema_name
        self.ctas_method = ctas_method

    @staticmethod
    def create_from(query_params: Dict[str, Any]) -> CreateTableAsSelect:
        """Build an instance from the raw query-parameter dictionary.

        Reads ``"ctas_method"`` (falling back to ``CtasMethod.TABLE`` when the
        key is absent), ``"schema"`` and ``"tmp_table_name"``.
        """
        return CreateTableAsSelect(
            ctas_method=query_params.get("ctas_method", CtasMethod.TABLE),
            target_schema_name=cast(str, query_params.get("schema")),
            target_table_name=cast(str, query_params.get("tmp_table_name")),
        )
46 changes: 8 additions & 38 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@
check_explore_cache_perms,
check_resource_permissions,
check_slice_perms,
get_cta_schema_name,
get_dashboard_extra_filters,
get_datasource_info,
get_form_data,
Expand Down Expand Up @@ -2567,7 +2566,7 @@ def is_query_handled(cls, query: Optional[Query]) -> bool:
QueryStatus.TIMED_OUT,
]

def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals
def sql_json_exec( # pylint: disable=too-many-statements
self,
execution_context: SqlJsonExecutionContext,
query_params: Dict[str, Any],
Expand All @@ -2580,42 +2579,13 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals
query = self._get_existing_query(execution_context, session)

if self.is_query_handled(query):
# return the existing query
payload = json.dumps(
{"query": query.to_dict()}, default=utils.json_int_dttm_ser # type: ignore
)
payload = self._convert_query_to_payload(cast(Query, query))
return json_success(payload)

mydb = self._get_the_query_db(execution_context, session)

# Set tmp_schema_name for CTA
# TODO(bkyryliuk): consider parsing, splitting tmp_schema_name from
# tmp_table_name if user enters
# <schema_name>.<table_name>
tmp_schema_name: Optional[str] = execution_context.schema
if execution_context.select_as_cta and mydb.force_ctas_schema:
tmp_schema_name = mydb.force_ctas_schema
elif execution_context.select_as_cta:
tmp_schema_name = get_cta_schema_name(
mydb, g.user, execution_context.schema, execution_context.sql
)

# Save current query
query = Query(
database_id=execution_context.database_id,
sql=execution_context.sql,
schema=execution_context.schema,
select_as_cta=execution_context.select_as_cta,
ctas_method=execution_context.ctas_method,
start_time=now_as_float(),
tab_name=execution_context.tab_name,
status=execution_context.status,
sql_editor_id=execution_context.sql_editor_id,
tmp_table_name=execution_context.tmp_table_name,
tmp_schema_name=tmp_schema_name,
user_id=execution_context.user_id,
client_id=execution_context.client_id_or_short_id,
execution_context.set_database(
self._get_the_query_db(execution_context, session)
)
query = execution_context.create_query()
try:
session.add(query)
session.flush()
Expand Down Expand Up @@ -2684,12 +2654,12 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals
},
)

# Limit is not applied to the CTA queries if SQLLAB_CTAS_NO_LIMIT flag is set
# to True.
if not (config.get("SQLLAB_CTAS_NO_LIMIT") and execution_context.select_as_cta):
# set LIMIT after template processing
limits = [
mydb.db_engine_spec.get_limit_from_sql(rendered_query),
execution_context.database.db_engine_spec.get_limit_from_sql( # type: ignore
rendered_query
),
execution_context.limit,
]
if limits[0] is None or limits[0] > limits[1]: # type: ignore
Expand Down
7 changes: 5 additions & 2 deletions tests/integration_tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def get_query_by_id(id: int):

@pytest.fixture(autouse=True, scope="module")
def setup_sqllab():

with app.app_context():
yield

Expand Down Expand Up @@ -216,7 +217,8 @@ def test_run_sync_query_cta_no_data(setup_sqllab):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@mock.patch(
"superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
"superset.utils.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
)
def test_run_sync_query_cta_config(setup_sqllab, ctas_method):
if backend() == "sqlite":
Expand All @@ -243,7 +245,8 @@ def test_run_sync_query_cta_config(setup_sqllab, ctas_method):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@mock.patch(
"superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
"superset.utils.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
)
def test_run_async_query_cta_config(setup_sqllab, ctas_method):
if backend() == "sqlite":
Expand Down
3 changes: 2 additions & 1 deletion tests/integration_tests/sqllab_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
QUERY_3 = "SELECT * FROM birth_names LIMIT 10"


@pytest.mark.sqllab
class TestSqlLab(SupersetTestCase):
"""Testings for Sql Lab"""

Expand Down Expand Up @@ -188,7 +189,7 @@ def test_sql_json_cta_dynamic_db(self, ctas_method):
return

with mock.patch(
"superset.views.core.get_cta_schema_name",
"superset.utils.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: f"{u.username}_database",
):
old_allow_ctas = examples_db.allow_ctas
Expand Down

0 comments on commit be77ad2

Please sign in to comment.