Skip to content

Commit

Permalink
fix(mssql): support top syntax for limiting queries (#18746)
Browse files Browse the repository at this point in the history
* SQL-TOP Fix For Database Engines

MSSQL does not support the LIMIT syntax in SQL. To limit the number of returned rows, MSSQL uses a different keyword, TOP. Added fixes for handling the TOP and LIMIT clauses based on the database engine.

* Teradata code for top clause handling removed from teradata.py

Teradata code for top clause handling removed from teradata.py file, since we added generic section in base engine for the same.

* Changes to handle CTE along with TOP in complex SQLs

Added changes to handle the TOP command in CTEs for DB engines that do not support inline CTEs.

* Test cases for TOP unit testing in MSSQL

Added multiple unit test cases for MSSQL top command handling and also along with CTEs

* Corrected the select_keywords key name in the base engine

Corrected the select_keywords key name in the base engine spec.

* Changes based on code review

Made the required corrections based on code review to keep the code readable and clean.

* Review changes to correct lint and typo issues

Made the changes according to the review comments.

* fix linting errors

* fix teradata tests

* add coverage

* lint

* Code cleanliness

Moved the top/limit flag check from sql_lab to core.

* Changed for code cleanliness

Changes for keeping code cleanliness

* Corrected lint issue

Corrected lint issue.

* Code cleanliness

Code cleanliness

Co-authored-by: Ville Brofeldt <ville.v.brofeldt@gmail.com>
  • Loading branch information
sujiplr and villebro committed Feb 21, 2022
1 parent a291537 commit 7e51b20
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 303 deletions.
75 changes: 75 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# If True, CTEs may be wrapped inside a subquery;
# if False, they are kept as regular CTEs
allows_cte_in_subquery = True
# Whether the engine allows the LIMIT clause in SQL
# If True, a LIMIT clause is used for row limiting
# If False, a TOP clause is used instead
allow_limit_clause = True
# Keywords identifying SELECT statements,
# used by engines that parse SQL for TOP clauses
select_keywords: Set[str] = {"SELECT"}
# Keywords identifying row-limit clauses,
# used by engines that parse SQL for TOP clauses
top_keywords: Set[str] = {"TOP"}

force_column_alias_quotes = False
arraysize = 0
Expand Down Expand Up @@ -649,6 +659,71 @@ def apply_limit_to_sql(

return sql

@classmethod
def apply_top_to_sql(cls, sql: str, limit: int) -> str:
    """
    Alters the SQL statement to apply a TOP clause.

    Used by engines that do not support LIMIT (see ``allow_limit_clause``).
    If the query already carries a TOP/SAMPLE value that is stricter than
    the requested limit, the existing value is kept.

    :param sql: SQL query
    :param limit: Maximum number of rows to be returned by the query
    :return: SQL query with top clause
    """

    cte = None
    sql_remainder = None
    sql = sql.strip(" \t\n;")
    sql_statement = sqlparse.format(sql, strip_comments=True)
    query_limit: Optional[int] = sql_parse.extract_top_from_query(
        sql_statement, cls.top_keywords
    )
    # Enforce the smaller of the existing TOP value and the requested limit
    if not limit:
        final_limit = query_limit
    elif int(query_limit or 0) < limit and query_limit is not None:
        final_limit = query_limit
    else:
        final_limit = limit
    if final_limit is None:
        # No limit requested and none present in the query: nothing to do.
        # (Previously this path could insert the literal "TOP None".)
        return sql
    if not cls.allows_cte_in_subquery:
        # Engines without inline-CTE support need the WITH clause split off
        # and re-attached in front of the rewritten statement
        cte, sql_remainder = sql_parse.get_cte_remainder_query(sql_statement)
    if cte:
        str_statement = str(sql_remainder)
        cte = cte + "\n"
    else:
        cte = ""
        str_statement = str(sql)
    str_statement = str_statement.replace("\n", " ").replace("\r", "")

    tokens = str_statement.rstrip().split(" ")
    tokens = [token for token in tokens if token]
    if cls.top_not_in_sql(str_statement):
        # No TOP clause yet: insert one right after the first SELECT keyword
        selects = [
            i
            for i, word in enumerate(tokens)
            if word.upper() in cls.select_keywords
        ]
        first_select = selects[0]
        tokens.insert(first_select + 1, "TOP")
        tokens.insert(first_select + 2, str(final_limit))

    next_is_limit_token = False
    new_tokens = []

    for token in tokens:
        # Case-insensitive match so lowercase "top"/"sample" clauses are
        # also rewritten; top_not_in_sql detects them case-insensitively,
        # so a case-sensitive check here would leave them untouched
        if token.upper() in cls.top_keywords:
            next_is_limit_token = True
        elif next_is_limit_token:
            if token.isdigit():
                token = str(final_limit)
                next_is_limit_token = False
        new_tokens.append(token)
    sql = " ".join(new_tokens)
    return cte + sql

@classmethod
def top_not_in_sql(cls, sql: str) -> bool:
    """
    Check whether the statement lacks a row-limiting keyword.

    :param sql: SQL query
    :return: True if none of the engine's TOP keywords occur in the query
    """
    # Substring check, case-insensitive on both sides
    upper_sql = sql.upper()
    return not any(keyword.upper() in upper_sql for keyword in cls.top_keywords)

@classmethod
def get_limit_from_sql(cls, sql: str) -> Optional[int]:
"""
Expand Down
1 change: 1 addition & 0 deletions superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class MssqlEngineSpec(BaseEngineSpec):
limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 128
allows_cte_in_subquery = False
allow_limit_clause = False

_time_grain_expressions = {
None: "{col}",
Expand Down
238 changes: 3 additions & 235 deletions superset/db_engine_specs/teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,216 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from typing import Optional, Set

import sqlparse
from sqlparse.sql import (
Identifier,
IdentifierList,
Parenthesis,
remove_quotes,
Token,
TokenList,
)
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.sql_parse import Table

PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
CTE_PREFIX = "CTE__"
JOIN = " JOIN"


def _extract_limit_from_query_td(statement: TokenList) -> Optional[int]:
td_limit_keywork = {"TOP", "SAMPLE"}
str_statement = str(statement)
str_statement = str_statement.replace("\n", " ").replace("\r", "")
token = str_statement.rstrip().split(" ")
token = [part for part in token if part]
limit = None

for i, _ in enumerate(token):
if token[i].upper() in td_limit_keywork and len(token) - 1 > i:
try:
limit = int(token[i + 1])
except ValueError:
limit = None
break
return limit


class ParsedQueryTeradata:
    """
    Teradata-specific SQL statement parser.

    Wraps sqlparse to extract the referenced tables and any existing
    TOP/SAMPLE row limit from a statement, and can rewrite the statement
    to enforce a new row limit (Teradata does not support LIMIT).
    """

    def __init__(
        self, sql_statement: str, strip_comments: bool = False, uri_type: str = "None"
    ):
        """
        :param sql_statement: raw SQL text to parse
        :param strip_comments: when True, drop SQL comments before parsing
        :param uri_type: stored but not used here — TODO confirm callers
        """
        if strip_comments:
            sql_statement = sqlparse.format(sql_statement, strip_comments=True)

        self.sql: str = sql_statement
        # Tables referenced by the statement, populated lazily by `tables`
        self._tables: Set[Table] = set()
        # Names found to be aliases, used to filter false table matches
        self._alias_names: Set[str] = set()
        # Existing TOP/SAMPLE limit of the statement, if any
        self._limit: Optional[int] = None
        self.uri_type: str = uri_type

        self._parsed = sqlparse.parse(self.stripped())
        # NOTE(review): for multi-statement SQL this keeps only the limit of
        # the LAST statement — presumably single statements are expected
        for statement in self._parsed:
            self._limit = _extract_limit_from_query_td(statement)

    @property
    def tables(self) -> Set[Table]:
        """Return the tables referenced by the statement (computed lazily),
        excluding names that are actually aliases."""
        if not self._tables:
            for statement in self._parsed:
                self._extract_from_token(statement)

            # Drop entries that turned out to be aliases, not real tables
            self._tables = {
                table for table in self._tables if str(table) not in self._alias_names
            }
        return self._tables

    def stripped(self) -> str:
        """Return the SQL stripped of surrounding whitespace and semicolons."""
        return self.sql.strip(" \t\n;")

    def _extract_from_token(self, token: Token) -> None:
        """
        <Identifier> store a list of subtokens and <IdentifierList> store lists of
        subtoken list.
        It extracts <IdentifierList> and <Identifier> from :param token: and loops
        through all subtokens recursively. It finds table_name_preceding_token and
        passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
        self._tables.
        :param token: instance of Token or child class, e.g. TokenList, to be processed
        """
        # Leaf tokens carry no sub-tokens and cannot reference tables
        if not hasattr(token, "tokens"):
            return

        # True while the previous keyword (FROM/JOIN/...) announces a table name
        table_name_preceding_token = False

        for item in token.tokens:
            # Recurse into groups that are not plain identifiers
            # (e.g. parenthesized subqueries)
            if item.is_group and (
                not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
            ):
                self._extract_from_token(item)

            if item.ttype in Keyword and (
                item.normalized in PRECEDES_TABLE_NAME or item.normalized.endswith(JOIN)
            ):
                # The next identifier(s) name a table
                table_name_preceding_token = True
                continue

            if item.ttype in Keyword:
                # Any other keyword ends the table-name context
                table_name_preceding_token = False
                continue
            if table_name_preceding_token:
                if isinstance(item, Identifier):
                    self._process_tokenlist(item)
                elif isinstance(item, IdentifierList):
                    for item_list in item.get_identifiers():
                        if isinstance(item_list, TokenList):
                            self._process_tokenlist(item_list)
            elif isinstance(item, IdentifierList):
                # Mixed lists may hide further groups worth recursing into
                if any(not self._is_identifier(ItemList) for ItemList in item.tokens):
                    self._extract_from_token(item)

    @staticmethod
    def _get_table(tlist: TokenList) -> Optional[Table]:
        """
        Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
        construct.
        :param tlist: The SQL tokens
        :returns: The table if the name conforms
        """

        # Strip the alias if present.
        idx = len(tlist.tokens)

        if tlist.has_alias():
            ws_idx, _ = tlist.token_next_by(t=Whitespace)

            if ws_idx != -1:
                idx = ws_idx

        tokens = tlist.tokens[:idx]

        # A qualified name has 1, 3 or 5 tokens: name (. name (. name))
        odd_token_number = len(tokens) in (1, 3, 5)
        qualified_name_parts = all(
            imt(token, t=[Name, String]) for token in tokens[::2]
        )
        dot_separators = all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
        if odd_token_number and qualified_name_parts and dot_separators:
            # tokens[::-2] walks the name parts right-to-left: table, schema, catalog
            return Table(*[remove_quotes(token.value) for token in tokens[::-2]])

        return None

    @staticmethod
    def _is_identifier(token: Token) -> bool:
        """Return True if the token is an Identifier or IdentifierList."""
        return isinstance(token, (IdentifierList, Identifier))

    def _process_tokenlist(self, token_list: TokenList) -> None:
        """
        Add table names to table set
        :param token_list: TokenList to be processed
        """
        # exclude subselects
        if "(" not in str(token_list):
            table = self._get_table(token_list)
            # CTE names are not real tables
            if table and not table.table.startswith(CTE_PREFIX):
                self._tables.add(table)
            return

        # store aliases
        if token_list.has_alias():
            self._alias_names.add(token_list.get_alias())

        # some aliases are not parsed properly
        if token_list.tokens[0].ttype == Name:
            self._alias_names.add(token_list.tokens[0].value)
        self._extract_from_token(token_list)

    def set_or_update_query_limit_td(self, new_limit: int) -> str:
        """
        Rewrite the statement so its row limit is at most ``new_limit``.

        Inserts a TOP clause after the first SELECT/SEL keyword when the
        query has none, or rewrites the value of an existing TOP/SAMPLE
        clause; an existing stricter limit is kept.

        :param new_limit: Maximum number of rows to be returned by the query
        :return: the rewritten SQL, flattened to a single line
        """
        td_sel_keywords = {"SELECT", "SEL"}
        td_limit_keywords = {"TOP", "SAMPLE"}
        statement = self._parsed[0]

        # Keep the smaller of the existing limit and the requested one
        if not self._limit:
            final_limit = new_limit
        elif new_limit < self._limit:
            final_limit = new_limit
        else:
            final_limit = self._limit

        str_statement = str(statement)
        str_statement = str_statement.replace("\n", " ").replace("\r", "")

        tokens = str_statement.rstrip().split(" ")
        tokens = [token for token in tokens if token]

        if limit_not_in_sql(str_statement, td_limit_keywords):
            # No limit clause yet: insert TOP right after the first SELECT/SEL
            selects = [i for i, word in enumerate(tokens) if word in td_sel_keywords]
            first_select = selects[0]
            tokens.insert(first_select + 1, "TOP")
            tokens.insert(first_select + 2, str(final_limit))

        # Rewrite the numeric value following an existing TOP/SAMPLE keyword
        next_is_limit_token = False
        new_tokens = []

        for token in tokens:
            if token.upper() in td_limit_keywords:
                next_is_limit_token = True
            elif next_is_limit_token:
                if token.isdigit():
                    token = str(final_limit)
                    next_is_limit_token = False
            new_tokens.append(token)

        return " ".join(new_tokens)


class TeradataEngineSpec(BaseEngineSpec):
Expand All @@ -234,6 +25,9 @@ class TeradataEngineSpec(BaseEngineSpec):
engine_name = "Teradata"
limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 30 # since 14.10 this is 128
allow_limit_clause = False
select_keywords = {"SELECT", "SEL"}
top_keywords = {"TOP", "SAMPLE"}

_time_grain_expressions = {
None: "{col}",
Expand All @@ -253,29 +47,3 @@ def epoch_to_dttm(cls) -> str:
"AT 0)) AT 0) + (({col} MOD 86400) * INTERVAL '00:00:01' "
"HOUR TO SECOND) AS TIMESTAMP(0))"
)

@classmethod
def apply_limit_to_sql(
    cls, sql: str, limit: int, database: str = "Database", force: bool = False
) -> str:
    """
    Alters the SQL statement to apply a TOP clause.

    Overrides the base implementation because Teradata does not support
    the LIMIT syntax.

    :param sql: SQL query
    :param limit: Maximum number of rows to be returned by the query
    :param database: Database instance (unused here)
    :param force: unused; kept for signature compatibility with the base class
    :return: SQL query with a TOP clause applied
    """
    return ParsedQueryTeradata(sql).set_or_update_query_limit_td(limit)


def limit_not_in_sql(sql: str, limit_words: Set[str]) -> bool:
    """
    Return True if none of the given limit keywords occur in the SQL.

    The comparison is case-insensitive, matching BaseEngineSpec.top_not_in_sql
    and the case-insensitive token rewriting done by the callers; the previous
    case-sensitive check caused a duplicate TOP clause to be inserted for
    lowercase queries such as ``select top 5 ...``.

    NOTE(review): this is a substring match, so a keyword embedded in another
    word (e.g. "TOP" inside "STOPPED") also counts as present — pre-existing
    behavior, kept for compatibility.

    :param sql: SQL statement text
    :param limit_words: row-limiting keywords, e.g. {"TOP", "SAMPLE"}
    :return: True when no limit keyword is present in the statement
    """
    sql_upper = sql.upper()
    for limit_word in limit_words:
        if limit_word.upper() in sql_upper:
            return False
    return True
4 changes: 3 additions & 1 deletion superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,9 @@ def select_star( # pylint: disable=too-many-arguments
def apply_limit_to_sql(
    self, sql: str, limit: int = 1000, force: bool = False
) -> str:
    """
    Add a row-count restriction to a SQL statement.

    Dispatches on the engine spec: engines that support LIMIT get a LIMIT
    clause, all others (e.g. MSSQL, Teradata) get a TOP clause.

    :param sql: SQL query
    :param limit: maximum number of rows the query may return
    :param force: forwarded to the engine spec's LIMIT handling
    :return: the limited SQL query
    """
    spec = self.db_engine_spec
    if not spec.allow_limit_clause:
        return spec.apply_top_to_sql(sql, limit)
    return spec.apply_limit_to_sql(sql, limit, self, force=force)

def safe_sqlalchemy_uri(self) -> str:
return self.sqlalchemy_uri
Expand Down
3 changes: 2 additions & 1 deletion superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ def apply_limit_if_exists(
) -> str:
if query.limit and increased_limit:
# We are fetching one more than the requested limit in order
# to test whether there are more rows than the limit.
# to test whether there are more rows than the limit. Depending on the
# database engine's support, either a TOP or a LIMIT clause is applied.
# Later, the extra row will be dropped before sending
# the results back to the user.
sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
Expand Down

0 comments on commit 7e51b20

Please sign in to comment.