Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: sql_json view endpoint: encapsulate ctas parameters #16548

Merged
merged 3 commits into from
Sep 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 85 additions & 9 deletions superset/utils/sqllab_execution_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@
import json
import logging
from dataclasses import dataclass
from typing import Any, cast, Dict, Optional
from typing import Any, cast, Dict, Optional, TYPE_CHECKING

from flask import g

from superset import app, is_feature_enabled
from superset.models.sql_lab import Query
from superset.sql_parse import CtasMethod
from superset.utils import core as utils
from superset.utils.dates import now_as_float
from superset.views.utils import get_cta_schema_name

if TYPE_CHECKING:
from superset.connectors.sqla.models import Database

QueryStatus = utils.QueryStatus
logger = logging.getLogger(__name__)
Expand All @@ -42,17 +48,18 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
async_flag: bool
limit: int
status: str
select_as_cta: bool
ctas_method: CtasMethod
tmp_table_name: str
client_id: str
client_id_or_short_id: str
sql_editor_id: str
tab_name: str
user_id: Optional[int]
expand_data: bool
create_table_as_select: Optional[CreateTableAsSelect]
database: Optional[Database]

def __init__(self, query_params: Dict[str, Any]):
self.create_table_as_select = None
self.database = None
self._init_from_query_params(query_params)
self.user_id = self._get_user_id()
self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10])
Expand All @@ -65,11 +72,8 @@ def _init_from_query_params(self, query_params: Dict[str, Any]) -> None:
self.async_flag = cast(bool, query_params.get("runAsync"))
self.limit = self._get_limit_param(query_params)
self.status = cast(str, query_params.get("status"))
self.select_as_cta = cast(bool, query_params.get("select_as_cta"))
self.ctas_method = cast(
CtasMethod, query_params.get("ctas_method", CtasMethod.TABLE)
)
self.tmp_table_name = cast(str, query_params.get("tmp_table_name"))
if cast(bool, query_params.get("select_as_cta")):
self.create_table_as_select = CreateTableAsSelect.create_from(query_params)
self.client_id = cast(str, query_params.get("client_id"))
self.sql_editor_id = cast(str, query_params.get("sql_editor_id"))
self.tab_name = cast(str, query_params.get("tab"))
Expand Down Expand Up @@ -109,3 +113,75 @@ def _get_user_id(self) -> Optional[int]: # pylint: disable=R0201

def is_run_asynchronous(self) -> bool:
    """Return the ``async_flag`` value stored on this execution context."""
    run_async = self.async_flag
    return run_async

@property
def select_as_cta(self) -> bool:
    """Tell whether a ``CreateTableAsSelect`` object is attached to this context."""
    ctas = self.create_table_as_select
    return ctas is not None

def set_database(self, database: Database) -> None:
    """Attach ``database`` to this context and resolve the CTAS target schema.

    The database is handed to ``_validate_db`` before it is stored on
    ``self.database``. When ``select_as_cta`` is true, the attached
    ``create_table_as_select`` object gets its ``target_schema_name``
    overwritten with the result of ``_get_ctas_target_schema_name``.
    """
    self._validate_db(database)
    self.database = database
    if not self.select_as_cta:
        # Nothing CTAS-related to resolve for a plain query.
        return
    target_schema = self._get_ctas_target_schema_name(database)
    self.create_table_as_select.target_schema_name = target_schema  # type: ignore

def _get_ctas_target_schema_name(self, database: Database) -> Optional[str]:
    """Pick the schema that a CTAS statement should write into.

    A truthy ``database.force_ctas_schema`` is returned as is; otherwise the
    name is delegated to ``get_cta_schema_name`` with the database, the
    current ``g.user`` and this context's ``schema`` and ``sql``.
    """
    forced_schema = database.force_ctas_schema
    if not forced_schema:
        return get_cta_schema_name(database, g.user, self.schema, self.sql)
    return forced_schema

def _validate_db(self, database: Database) -> None:
    """Placeholder validation hook: performs no check and returns ``None``."""
    # TODO validate db.id is equal to self.database_id
    return None

def create_query(self) -> Query:
    """Build a ``Query`` model object from the values held by this context.

    The keyword arguments shared by every query are collected once. When
    ``select_as_cta`` is true, the CTAS-specific ones (``ctas_method``,
    ``tmp_table_name``, ``tmp_schema_name``) are added from the attached
    ``create_table_as_select`` object and ``select_as_cta`` is set to True.
    """
    start_time = now_as_float()
    query_kwargs: Dict[str, Any] = {
        "database_id": self.database_id,
        "sql": self.sql,
        "schema": self.schema,
        "select_as_cta": False,
        "start_time": start_time,
        "tab_name": self.tab_name,
        "status": self.status,
        "sql_editor_id": self.sql_editor_id,
        "user_id": self.user_id,
        "client_id": self.client_id_or_short_id,
    }
    if self.select_as_cta:
        ctas = self.create_table_as_select
        query_kwargs.update(
            select_as_cta=True,
            ctas_method=ctas.ctas_method,  # type: ignore
            tmp_table_name=ctas.target_table_name,  # type: ignore
            tmp_schema_name=ctas.target_schema_name,  # type: ignore
        )
    return Query(**query_kwargs)


class CreateTableAsSelect:  # pylint: disable=R0903
    """Parameters of a "create table as select" (CTAS) request.

    Holds the CTAS method, the name of the table to create and the schema
    it should be created in. ``target_schema_name`` may be ``None`` and is
    a plain attribute, so it can be reassigned after construction.
    """

    ctas_method: CtasMethod
    target_schema_name: Optional[str]
    target_table_name: str

    def __init__(
        self,
        ctas_method: CtasMethod,
        target_schema_name: Optional[str],
        target_table_name: str,
    ):
        """Store the three CTAS parameters unchanged.

        :param ctas_method: how the result is materialised
        :param target_schema_name: schema to create the table in, if known
        :param target_table_name: name of the table to create
        """
        self.ctas_method = ctas_method
        self.target_schema_name = target_schema_name
        self.target_table_name = target_table_name

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"ctas_method={self.ctas_method!r}, "
            f"target_schema_name={self.target_schema_name!r}, "
            f"target_table_name={self.target_table_name!r})"
        )

    @classmethod
    def create_from(cls, query_params: Dict[str, Any]) -> "CreateTableAsSelect":
        """Alternate constructor reading the parameters from ``query_params``.

        Reads the ``ctas_method`` key (falling back to ``CtasMethod.TABLE``
        when the key is absent), the ``schema`` key and the
        ``tmp_table_name`` key. Values are passed through without
        conversion; ``cast`` only informs the type checker.

        :param query_params: request parameters of the SQL Lab call
        :returns: a new instance of the class this method is called on
        """
        ctas_method = query_params.get("ctas_method", CtasMethod.TABLE)
        # "schema" may be missing, hence Optional rather than str.
        schema = cast(Optional[str], query_params.get("schema"))
        tmp_table_name = cast(str, query_params.get("tmp_table_name"))
        return cls(ctas_method, schema, tmp_table_name)
46 changes: 8 additions & 38 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@
check_explore_cache_perms,
check_resource_permissions,
check_slice_perms,
get_cta_schema_name,
get_dashboard_extra_filters,
get_datasource_info,
get_form_data,
Expand Down Expand Up @@ -2567,7 +2566,7 @@ def is_query_handled(cls, query: Optional[Query]) -> bool:
QueryStatus.TIMED_OUT,
]

def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals
def sql_json_exec( # pylint: disable=too-many-statements
self,
execution_context: SqlJsonExecutionContext,
query_params: Dict[str, Any],
Expand All @@ -2580,42 +2579,13 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals
query = self._get_existing_query(execution_context, session)

if self.is_query_handled(query):
# return the existing query
payload = json.dumps(
{"query": query.to_dict()}, default=utils.json_int_dttm_ser # type: ignore
)
payload = self._convert_query_to_payload(cast(Query, query))
return json_success(payload)

mydb = self._get_the_query_db(execution_context, session)

# Set tmp_schema_name for CTA
# TODO(bkyryliuk): consider parsing, splitting tmp_schema_name from
# tmp_table_name if user enters
# <schema_name>.<table_name>
tmp_schema_name: Optional[str] = execution_context.schema
if execution_context.select_as_cta and mydb.force_ctas_schema:
tmp_schema_name = mydb.force_ctas_schema
elif execution_context.select_as_cta:
tmp_schema_name = get_cta_schema_name(
mydb, g.user, execution_context.schema, execution_context.sql
)

# Save current query
query = Query(
database_id=execution_context.database_id,
sql=execution_context.sql,
schema=execution_context.schema,
select_as_cta=execution_context.select_as_cta,
ctas_method=execution_context.ctas_method,
start_time=now_as_float(),
tab_name=execution_context.tab_name,
status=execution_context.status,
sql_editor_id=execution_context.sql_editor_id,
tmp_table_name=execution_context.tmp_table_name,
tmp_schema_name=tmp_schema_name,
user_id=execution_context.user_id,
client_id=execution_context.client_id_or_short_id,
execution_context.set_database(
self._get_the_query_db(execution_context, session)
)
query = execution_context.create_query()
try:
session.add(query)
session.flush()
Expand Down Expand Up @@ -2684,12 +2654,12 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals
},
)

# Limit is not applied to the CTA queries if SQLLAB_CTAS_NO_LIMIT flag is set
# to True.
if not (config.get("SQLLAB_CTAS_NO_LIMIT") and execution_context.select_as_cta):
# set LIMIT after template processing
limits = [
mydb.db_engine_spec.get_limit_from_sql(rendered_query),
execution_context.database.db_engine_spec.get_limit_from_sql( # type: ignore
rendered_query
),
execution_context.limit,
]
if limits[0] is None or limits[0] > limits[1]: # type: ignore
Expand Down
7 changes: 5 additions & 2 deletions tests/integration_tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def get_query_by_id(id: int):

@pytest.fixture(autouse=True, scope="module")
def setup_sqllab():

with app.app_context():
yield

Expand Down Expand Up @@ -216,7 +217,8 @@ def test_run_sync_query_cta_no_data(setup_sqllab):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@mock.patch(
"superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
"superset.utils.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
)
def test_run_sync_query_cta_config(setup_sqllab, ctas_method):
if backend() == "sqlite":
Expand All @@ -243,7 +245,8 @@ def test_run_sync_query_cta_config(setup_sqllab, ctas_method):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@mock.patch(
"superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
"superset.utils.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
)
def test_run_async_query_cta_config(setup_sqllab, ctas_method):
if backend() == "sqlite":
Expand Down
3 changes: 2 additions & 1 deletion tests/integration_tests/sqllab_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
QUERY_3 = "SELECT * FROM birth_names LIMIT 10"


@pytest.mark.sqllab
class TestSqlLab(SupersetTestCase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove that custom mark

"""Testings for Sql Lab"""

Expand Down Expand Up @@ -188,7 +189,7 @@ def test_sql_json_cta_dynamic_db(self, ctas_method):
return

with mock.patch(
"superset.views.core.get_cta_schema_name",
"superset.utils.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: f"{u.username}_database",
):
old_allow_ctas = examples_db.allow_ctas
Expand Down