fix(mssql): support top syntax for limiting queries #18746

Merged
merged 16 commits on Feb 21, 2022
75 changes: 75 additions & 0 deletions superset/db_engine_specs/base.py
@@ -300,6 +300,16 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
# If True, CTEs are allowed inside a subquery;
# if False, they are kept as regular CTEs
allows_cte_in_subquery = True
# Whether the engine supports limiting rows with a LIMIT clause.
# If True, row limits are applied with LIMIT;
# if False, they are applied with a TOP clause instead.
allow_limit_clause = True
# Keywords that mark the start of a SELECT statement, used when
# parsing queries for engines that limit rows with TOP
select_keywords: Set[str] = {"SELECT"}
# Keywords that introduce a row-limit clause, used when
# parsing queries for engines that limit rows with TOP
top_keywords: Set[str] = {"TOP"}

force_column_alias_quotes = False
arraysize = 0
@@ -649,6 +659,71 @@ def apply_limit_to_sql(

return sql

@classmethod
def apply_top_to_sql(cls, sql: str, limit: int) -> str:
"""
Alters the SQL statement to apply a TOP clause.
:param sql: SQL query
:param limit: Maximum number of rows to be returned by the query
:return: SQL query with a TOP clause applied
"""

cte = None
sql_remainder = None
sql = sql.strip(" \t\n;")
sql_statement = sqlparse.format(sql, strip_comments=True)
query_limit: Optional[int] = sql_parse.extract_top_from_query(
sql_statement, cls.top_keywords
)
if not limit:
final_limit = query_limit
elif int(query_limit or 0) < limit and query_limit is not None:
final_limit = query_limit
else:
final_limit = limit
if not cls.allows_cte_in_subquery:
cte, sql_remainder = sql_parse.get_cte_remainder_query(sql_statement)
if cte:
str_statement = str(sql_remainder)
cte = cte + "\n"
else:
cte = ""
str_statement = str(sql)
str_statement = str_statement.replace("\n", " ").replace("\r", "")

tokens = str_statement.rstrip().split(" ")
tokens = [token for token in tokens if token]
if cls.top_not_in_sql(str_statement):
selects = [
i
for i, word in enumerate(tokens)
if word.upper() in cls.select_keywords
]
first_select = selects[0]
tokens.insert(first_select + 1, "TOP")
tokens.insert(first_select + 2, str(final_limit))

next_is_limit_token = False
new_tokens = []

for token in tokens:
if token in cls.top_keywords:
next_is_limit_token = True
elif next_is_limit_token:
if token.isdigit():
token = str(final_limit)
next_is_limit_token = False
new_tokens.append(token)
sql = " ".join(new_tokens)
return cte + sql

@classmethod
def top_not_in_sql(cls, sql: str) -> bool:
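"""Return True if the statement does not already contain a TOP-style keyword."""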
for top_word in cls.top_keywords:
if top_word.upper() in sql.upper():
return False
return True

@classmethod
def get_limit_from_sql(cls, sql: str) -> Optional[int]:
"""
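For reference, here is a minimal standalone sketch of the token-level idea behind the apply_top_to_sql method added above — simplified and not the Superset implementation (it skips the CTE handling and sqlparse pre-processing), with apply_top_sketch being a hypothetical helper name: inject "TOP <limit>" after the first SELECT-style keyword, or cap an existing TOP value.

from typing import Sequence


def apply_top_sketch(
    sql: str,
    limit: int,
    select_keywords: Sequence[str] = ("SELECT",),
    top_keywords: Sequence[str] = ("TOP",),
) -> str:
    # Tokenize on whitespace, mirroring the string-splitting approach above.
    tokens = [t for t in sql.strip(" \t\n;").replace("\n", " ").split(" ") if t]

    if not any(t.upper() in top_keywords for t in tokens):
        # No TOP-style keyword yet: inject "TOP <limit>" right after the
        # first SELECT-style keyword.
        first_select = next(
            i for i, t in enumerate(tokens) if t.upper() in select_keywords
        )
        tokens[first_select + 1 : first_select + 1] = ["TOP", str(limit)]
        return " ".join(tokens)

    # A TOP-style keyword already exists: cap its numeric argument at the new limit.
    new_tokens = []
    expect_number = False
    for token in tokens:
        if token.upper() in top_keywords:
            expect_number = True
        elif expect_number and token.isdigit():
            token = str(min(int(token), limit))
            expect_number = False
        new_tokens.append(token)
    return " ".join(new_tokens)


assert apply_top_sketch("SELECT * FROM t", 100) == "SELECT TOP 100 * FROM t"
assert apply_top_sketch("SELECT TOP 5000 * FROM t", 100) == "SELECT TOP 100 * FROM t"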
3 changes: 3 additions & 0 deletions superset/db_engine_specs/mssql.py
@@ -48,6 +48,9 @@ class MssqlEngineSpec(BaseEngineSpec):
limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 128
allows_cte_in_subquery = False
allow_limit_clause = False
select_keywords = {"SELECT"}
top_keywords = {"TOP"}

_time_grain_expressions = {
None: "{col}",
238 changes: 3 additions & 235 deletions superset/db_engine_specs/teradata.py
@@ -15,216 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from typing import Optional, Set

import sqlparse
from sqlparse.sql import (
Identifier,
IdentifierList,
Parenthesis,
remove_quotes,
Token,
TokenList,
)
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.sql_parse import Table

PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
CTE_PREFIX = "CTE__"
JOIN = " JOIN"


def _extract_limit_from_query_td(statement: TokenList) -> Optional[int]:
td_limit_keywork = {"TOP", "SAMPLE"}
str_statement = str(statement)
str_statement = str_statement.replace("\n", " ").replace("\r", "")
token = str_statement.rstrip().split(" ")
token = [part for part in token if part]
limit = None

for i, _ in enumerate(token):
if token[i].upper() in td_limit_keywork and len(token) - 1 > i:
try:
limit = int(token[i + 1])
except ValueError:
limit = None
break
return limit


class ParsedQueryTeradata:
def __init__(
self, sql_statement: str, strip_comments: bool = False, uri_type: str = "None"
):

if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)

self.sql: str = sql_statement
self._tables: Set[Table] = set()
self._alias_names: Set[str] = set()
self._limit: Optional[int] = None
self.uri_type: str = uri_type

self._parsed = sqlparse.parse(self.stripped())
for statement in self._parsed:
self._limit = _extract_limit_from_query_td(statement)

@property
def tables(self) -> Set[Table]:
if not self._tables:
for statement in self._parsed:
self._extract_from_token(statement)

self._tables = {
table for table in self._tables if str(table) not in self._alias_names
}
return self._tables

def stripped(self) -> str:
return self.sql.strip(" \t\n;")

def _extract_from_token(self, token: Token) -> None:
"""
<Identifier> store a list of subtokens and <IdentifierList> store lists of
subtoken list.

It extracts <IdentifierList> and <Identifier> from :param token: and loops
through all subtokens recursively. It finds table_name_preceding_token and
passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate

self._tables.

:param token: instance of Token or child class, e.g. TokenList, to be processed
"""
if not hasattr(token, "tokens"):
return

table_name_preceding_token = False

for item in token.tokens:
if item.is_group and (
not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
):
self._extract_from_token(item)

if item.ttype in Keyword and (
item.normalized in PRECEDES_TABLE_NAME or item.normalized.endswith(JOIN)
):
table_name_preceding_token = True
continue

if item.ttype in Keyword:
table_name_preceding_token = False
continue
if table_name_preceding_token:
if isinstance(item, Identifier):
self._process_tokenlist(item)
elif isinstance(item, IdentifierList):
for item_list in item.get_identifiers():
if isinstance(item_list, TokenList):
self._process_tokenlist(item_list)
elif isinstance(item, IdentifierList):
if any(not self._is_identifier(ItemList) for ItemList in item.tokens):
self._extract_from_token(item)

@staticmethod
def _get_table(tlist: TokenList) -> Optional[Table]:
"""
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
construct.

:param tlist: The SQL tokens
:returns: The table if the name conforms
"""

# Strip the alias if present.
idx = len(tlist.tokens)

if tlist.has_alias():
ws_idx, _ = tlist.token_next_by(t=Whitespace)

if ws_idx != -1:
idx = ws_idx

tokens = tlist.tokens[:idx]

odd_token_number = len(tokens) in (1, 3, 5)
qualified_name_parts = all(
imt(token, t=[Name, String]) for token in tokens[::2]
)
dot_separators = all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
if odd_token_number and qualified_name_parts and dot_separators:
return Table(*[remove_quotes(token.value) for token in tokens[::-2]])

return None

@staticmethod
def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))

def _process_tokenlist(self, token_list: TokenList) -> None:
"""
Add table names to table set

:param token_list: TokenList to be processed
"""
# exclude subselects
if "(" not in str(token_list):
table = self._get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return

# store aliases
if token_list.has_alias():
self._alias_names.add(token_list.get_alias())

# some aliases are not parsed properly
if token_list.tokens[0].ttype == Name:
self._alias_names.add(token_list.tokens[0].value)
self._extract_from_token(token_list)

def set_or_update_query_limit_td(self, new_limit: int) -> str:
td_sel_keywords = {"SELECT", "SEL"}
td_limit_keywords = {"TOP", "SAMPLE"}
statement = self._parsed[0]

if not self._limit:
final_limit = new_limit
elif new_limit < self._limit:
final_limit = new_limit
else:
final_limit = self._limit

str_statement = str(statement)
str_statement = str_statement.replace("\n", " ").replace("\r", "")

tokens = str_statement.rstrip().split(" ")
tokens = [token for token in tokens if token]

if limit_not_in_sql(str_statement, td_limit_keywords):
selects = [i for i, word in enumerate(tokens) if word in td_sel_keywords]
first_select = selects[0]
tokens.insert(first_select + 1, "TOP")
tokens.insert(first_select + 2, str(final_limit))

next_is_limit_token = False
new_tokens = []

for token in tokens:
if token.upper() in td_limit_keywords:
next_is_limit_token = True
elif next_is_limit_token:
if token.isdigit():
token = str(final_limit)
next_is_limit_token = False
new_tokens.append(token)

return " ".join(new_tokens)


class TeradataEngineSpec(BaseEngineSpec):
@@ -234,6 +25,9 @@ class TeradataEngineSpec(BaseEngineSpec):
engine_name = "Teradata"
limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 30 # since 14.10 this is 128
allow_limit_clause = False
select_keywords = {"SELECT", "SEL"}
top_keywords = {"TOP", "SAMPLE"}

_time_grain_expressions = {
None: "{col}",
@@ -253,29 +47,3 @@ def epoch_to_dttm(cls) -> str:
"AT 0)) AT 0) + (({col} MOD 86400) * INTERVAL '00:00:01' "
"HOUR TO SECOND) AS TIMESTAMP(0))"
)

@classmethod
def apply_limit_to_sql(
cls, sql: str, limit: int, database: str = "Database", force: bool = False
) -> str:
"""
Alters the SQL statement to apply a TOP clause
The function overwrites similar function in base.py because Teradata doesn't
support LIMIT syntax
:param sql: SQL query
:param limit: Maximum number of rows to be returned by the query
:param database: Database instance
:return: SQL query with limit clause
"""

parsed_query = ParsedQueryTeradata(sql)
sql = parsed_query.set_or_update_query_limit_td(limit)

return sql


def limit_not_in_sql(sql: str, limit_words: Set[str]) -> bool:
for limit_word in limit_words:
if limit_word in sql:
return False
return True
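With the Teradata-specific parser removed, the only dialect-specific pieces left are the keyword sets configured above. A tiny hedged illustration of why these are sets rather than single keywords — a simplified stand-in for top_not_in_sql, not Superset code, and has_row_limit is a hypothetical name:

from typing import AbstractSet

TD_TOP_KEYWORDS: AbstractSet[str] = {"TOP", "SAMPLE"}


def has_row_limit(sql: str, top_keywords: AbstractSet[str] = TD_TOP_KEYWORDS) -> bool:
    # Teradata accepts SEL as well as SELECT, and SAMPLE as well as TOP,
    # so detection has to be driven by per-engine keyword sets rather than
    # hard-coded literals.
    return any(token.upper() in top_keywords for token in sql.split())


assert has_row_limit("SEL TOP 10 * FROM t")
assert has_row_limit("SELECT * FROM t SAMPLE 1000")
assert not has_row_limit("SEL * FROM t")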
3 changes: 3 additions & 0 deletions superset/models/core.py
@@ -490,6 +490,9 @@ def apply_limit_to_sql(
) -> str:
return self.db_engine_spec.apply_limit_to_sql(sql, limit, self, force=force)

def apply_top_to_sql(self, sql: str, limit: int = 1000) -> str:
return self.db_engine_spec.apply_top_to_sql(sql, limit)

def safe_sqlalchemy_uri(self) -> str:
return self.sqlalchemy_uri

8 changes: 6 additions & 2 deletions superset/sql_lab.py
@@ -292,10 +292,14 @@ def apply_limit_if_exists(
) -> str:
if query.limit and increased_limit:
# We are fetching one more than the requested limit in order
# to test whether there are more rows than the limit. Depending on the
# engine spec, the limit is applied with either a LIMIT or a TOP clause.
# Later, the extra row will be dropped before sending
# the results back to the user.
if database.db_engine_spec.allow_limit_clause:
sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
else:
sql = database.apply_top_to_sql(sql, increased_limit)
return sql


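For context, a minimal runnable sketch of the dispatch this hunk introduces; FakeSpec and FakeDatabase are illustration-only stand-ins, not Superset classes, and their apply_* methods only mimic the expected output shape:

from dataclasses import dataclass


@dataclass
class FakeSpec:
    allow_limit_clause: bool


@dataclass
class FakeDatabase:
    db_engine_spec: FakeSpec

    def apply_limit_to_sql(self, sql: str, limit: int, force: bool = False) -> str:
        return f"{sql} LIMIT {limit}"

    def apply_top_to_sql(self, sql: str, limit: int) -> str:
        return sql.replace("SELECT", f"SELECT TOP {limit}", 1)


def limited_sql(database: FakeDatabase, sql: str, increased_limit: int) -> str:
    # LIMIT-capable engines keep the existing code path; TOP-only engines
    # (MSSQL, Teradata) go through the new apply_top_to_sql helper.
    if database.db_engine_spec.allow_limit_clause:
        return database.apply_limit_to_sql(sql, increased_limit, force=True)
    return database.apply_top_to_sql(sql, increased_limit)


assert limited_sql(FakeDatabase(FakeSpec(True)), "SELECT * FROM t", 101) == (
    "SELECT * FROM t LIMIT 101"
)
assert limited_sql(FakeDatabase(FakeSpec(False)), "SELECT * FROM t", 101) == (
    "SELECT TOP 101 * FROM t"
)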