Skip to content

Commit

Permalink
fix(mssql): support cte in virtual tables (#18567)
Browse files Browse the repository at this point in the history
* Fix for handling regular CTE queries with MSSQL, #8074

* Moved the get_cte_query function from mssql.py to base.py so it can be used irrespective of the DB engine

* Fix for handling regular CTE queries with MSSQL, #8074

* Moved the get_cte_query function from mssql.py to base.py so it can be used irrespective of the DB engine

* Unit test added for the db engine CTE SQL parsing.

Unit test added for the db engine CTE SQL parsing.  Removed additional spaces from the CTE parsing SQL generation.

* implement in sqla model

* lint + cleanup

Co-authored-by: Ville Brofeldt <ville.v.brofeldt@gmail.com>
  • Loading branch information
sujiplr and villebro committed Feb 10, 2022
1 parent 00eb6b1 commit b8aef10
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 13 deletions.
47 changes: 35 additions & 12 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
get_physical_table_metadata,
get_virtual_table_metadata,
)
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression
from superset.exceptions import QueryObjectValidationError
from superset.jinja_context import (
BaseTemplateProcessor,
Expand Down Expand Up @@ -107,6 +107,7 @@

class SqlaQuery(NamedTuple):
applied_template_filters: List[str]
cte: Optional[str]
extra_cache_keys: List[Any]
labels_expected: List[str]
prequeries: List[str]
Expand Down Expand Up @@ -562,6 +563,19 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def __repr__(self) -> str:
return self.name

@staticmethod
def _apply_cte(sql: str, cte: Optional[str]) -> str:
    """
    Prefix a SELECT statement with its CTE, when one is defined.

    :param sql: SELECT statement
    :param cte: CTE statement; nothing is prepended when it is None or empty
    :return: the CTE and the SELECT joined by a newline, or ``sql`` unchanged
    """
    return f"{cte}\n{sql}" if cte else sql

@property
def db_engine_spec(self) -> Type[BaseEngineSpec]:
return self.database.db_engine_spec
Expand Down Expand Up @@ -743,20 +757,18 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name]
tp = self.get_template_processor()
tbl, cte = self.get_from_clause(tp)

qry = (
select([target_col.get_sqla_col()])
.select_from(self.get_from_clause(tp))
.distinct()
)
qry = select([target_col.get_sqla_col()]).select_from(tbl).distinct()
if limit:
qry = qry.limit(limit)

if self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())

engine = self.database.get_sqla_engine()
sql = "{}".format(qry.compile(engine, compile_kwargs={"literal_binds": True}))
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)

df = pd.read_sql_query(sql=sql, con=engine)
Expand All @@ -778,6 +790,7 @@ def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = self._apply_cte(sql, sqlaq.cte)
sql = sqlparse.format(sql, reindent=True)
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
Expand All @@ -800,13 +813,14 @@ def get_sqla_table(self) -> TableClause:

def get_from_clause(
self, template_processor: Optional[BaseTemplateProcessor] = None
) -> Union[TableClause, Alias]:
) -> Tuple[Union[TableClause, Alias], Optional[str]]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with it's own subquery.
or a virtual table with its own subquery. If the FROM is referencing a
CTE, the CTE is returned as the second value in the return tuple.
"""
if not self.is_virtual:
return self.get_sqla_table()
return self.get_sqla_table(), None

from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
Expand All @@ -817,7 +831,15 @@ def get_from_clause(
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)
return TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)

cte = self.db_engine_spec.get_cte_query(from_sql)
from_clause = (
table(CTE_ALIAS)
if cte
else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
)

return from_clause, cte

def get_rendered_sql(
self, template_processor: Optional[BaseTemplateProcessor] = None
Expand Down Expand Up @@ -1224,7 +1246,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma

qry = sa.select(select_exprs)

tbl = self.get_from_clause(template_processor)
tbl, cte = self.get_from_clause(template_processor)

if groupby_all_columns:
qry = qry.group_by(*groupby_all_columns.values())
Expand Down Expand Up @@ -1491,6 +1513,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma

return SqlaQuery(
applied_template_filters=applied_template_filters,
cte=cte,
extra_cache_keys=extra_cache_keys,
labels_expected=labels_expected,
sqla_query=qry,
Expand Down
34 changes: 34 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import TypeEngine
from sqlparse.tokens import CTE
from typing_extensions import TypedDict

from superset import security_manager, sql_parse
Expand All @@ -80,6 +81,9 @@
logger = logging.getLogger()


CTE_ALIAS = "__cte"


class TimeGrain(NamedTuple):
name: str # TODO: redundant field, remove
label: str
Expand Down Expand Up @@ -292,6 +296,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# But for backward compatibility, False by default
allows_hidden_cc_in_orderby = False

# Whether the engine allows a CTE inside a subquery.
# If True, a virtual dataset's CTE query is used as a subquery as-is;
# if False, the CTE is kept as a regular top-level CTE.
allows_cte_in_subquery = True

force_column_alias_quotes = False
arraysize = 0
max_column_name_length = 0
Expand Down Expand Up @@ -663,6 +672,31 @@ def set_or_update_query_limit(cls, sql: str, limit: int) -> str:
parsed_query = sql_parse.ParsedQuery(sql)
return parsed_query.set_or_update_query_limit(limit)

@classmethod
def get_cte_query(cls, sql: str) -> Optional[str]:
    """
    Convert CTE-based SQL into a form usable as a virtual table.

    Only applies to engines where ``allows_cte_in_subquery`` is False. The
    CTE definitions following ``WITH`` are kept, and the rest of the statement
    (the main query) is appended as one more CTE named ``CTE_ALIAS``.

    :param sql: SQL query
    :return: CTE with the main select query aliased as `__cte`, or None when
        the engine allows CTEs in subqueries or the SQL is not a CTE query
    """
    if cls.allows_cte_in_subquery:
        return None

    statements = sqlparse.parse(sql)
    if not statements:
        # Empty SQL yields no statements; there is nothing to rewrite.
        return None
    stmt = statements[0]

    # The first meaningful token of a CTE query is the WITH keyword
    idx, with_token = stmt.token_next(-1, skip_ws=True, skip_cm=True)
    if not (with_token and with_token.ttype == CTE):
        return None

    # The token following WITH holds the CTE definition(s)
    idx, cte_token = stmt.token_next(idx)
    if cte_token is None:
        # Dangling WITH without any definition: not a usable CTE query.
        return None

    # Everything after the CTE definitions is the main query
    remainder = "".join(str(tok) for tok in stmt.tokens[idx + 1 :]).strip()
    return f"WITH {cte_token.value},\n{CTE_ALIAS} AS (\n{remainder}\n)"

@classmethod
def df_to_sql(
cls,
Expand Down
1 change: 1 addition & 0 deletions superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class MssqlEngineSpec(BaseEngineSpec):
engine_name = "Microsoft SQL Server"
limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 128
allows_cte_in_subquery = False

_time_grain_expressions = {
None: "{col}",
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@ def test_comments_in_sqlatable_query(self):
sql=commented_query,
database=get_example_database(),
)
rendered_query = str(table.get_from_clause())
rendered_query = str(table.get_from_clause()[0])
self.assertEqual(clean_query, rendered_query)

def test_slice_payload_no_datasource(self):
Expand Down
43 changes: 43 additions & 0 deletions tests/unit_tests/db_engine_specs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access

from textwrap import dedent

import pytest
from flask.ctx import AppContext
from sqlalchemy.types import TypeEngine


def test_get_text_clause_with_colon(app_context: AppContext) -> None:
Expand Down Expand Up @@ -56,3 +60,42 @@ def test_parse_sql_multi_statement(app_context: AppContext) -> None:
"SELECT foo FROM tbl1",
"SELECT bar FROM tbl2",
]


# BaseEngineSpec leaves allows_cte_in_subquery = True, so get_cte_query is
# expected to return None for every case below, including the WITH query.
@pytest.mark.parametrize(
"original,expected",
[
(
dedent(
"""
with currency as
(
select 'INR' as cur
)
select * from currency
"""
),
None,
),
("SELECT 1 as cnt", None,),
(
dedent(
"""
select 'INR' as cur
union
select 'AUD' as cur
union
select 'USD' as cur
"""
),
None,
),
],
)
def test_cte_query_parsing(
app_context: AppContext, original: str, expected: str
) -> None:
"""
Check that BaseEngineSpec.get_cte_query returns the expected value
(None in all parametrized cases) for the given SQL.
"""
from superset.db_engine_specs.base import BaseEngineSpec

actual = BaseEngineSpec.get_cte_query(original)
assert actual == expected
51 changes: 51 additions & 0 deletions tests/unit_tests/db_engine_specs/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,57 @@ def test_column_datatype_to_string(
assert actual == expected


# MssqlEngineSpec sets allows_cte_in_subquery = False, so a query starting
# with WITH is rewritten with its main SELECT wrapped in a `__cte` CTE, while
# queries that do not start with WITH yield None.
@pytest.mark.parametrize(
"original,expected",
[
(
dedent(
"""
with currency as (
select 'INR' as cur
),
currency_2 as (
select 'EUR' as cur
)
select * from currency union all select * from currency_2
"""
),
dedent(
"""WITH currency as (
select 'INR' as cur
),
currency_2 as (
select 'EUR' as cur
),
__cte AS (
select * from currency union all select * from currency_2
)"""
),
),
("SELECT 1 as cnt", None,),
(
dedent(
"""
select 'INR' as cur
union
select 'AUD' as cur
union
select 'USD' as cur
"""
),
None,
),
],
)
def test_cte_query_parsing(
app_context: AppContext, original: str, expected: str
) -> None:
"""
Check that MssqlEngineSpec.get_cte_query returns the expected rewritten
CTE (or None for non-CTE input) for the given SQL.
"""
from superset.db_engine_specs.mssql import MssqlEngineSpec

actual = MssqlEngineSpec.get_cte_query(original)
assert actual == expected


def test_extract_errors(app_context: AppContext) -> None:
"""
Test that custom error messages are extracted correctly.
Expand Down

0 comments on commit b8aef10

Please sign in to comment.