diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index e2f4f70ec85d..669a50562564 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -41,7 +41,7 @@ from tableschema import Table from werkzeug.utils import secure_filename -from superset import app, cache_util, conf, db, utils +from superset import app, cache_util, conf, db, sql_parse, utils from superset.exceptions import SupersetTemplateException from superset.utils import QueryStatus @@ -110,32 +110,19 @@ def apply_limit_to_sql(cls, sql, limit, database): ) return database.compile_sqla_query(qry) elif LimitMethod.FORCE_LIMIT: - sql_without_limit = cls.get_query_without_limit(sql) - return '{sql_without_limit} LIMIT {limit}'.format(**locals()) + parsed_query = sql_parse.SupersetQuery(sql) + sql = parsed_query.get_query_with_new_limit(limit) return sql @classmethod def get_limit_from_sql(cls, sql): - limit_pattern = re.compile(r""" - (?ix) # case insensitive, verbose - \s+ # whitespace - LIMIT\s+(\d+) # LIMIT $ROWS - ;? # optional semi-colon - (\s|;)*$ # remove trailing spaces tabs or semicolons - """) - matches = limit_pattern.findall(sql) - if matches: - return int(matches[0][0]) - - @classmethod - def get_query_without_limit(cls, sql): - return re.sub(r""" - (?ix) # case insensitive, verbose - \s+ # whitespace - LIMIT\s+\d+ # LIMIT $ROWS - ;? # optional semi-colon - (\s|;)*$ # remove trailing spaces tabs or semicolons - """, '', sql) + parsed_query = sql_parse.SupersetQuery(sql) + return parsed_query.limit + + @classmethod + def get_query_with_new_limit(cls, sql, limit): + parsed_query = sql_parse.SupersetQuery(sql) + return parsed_query.get_query_with_new_limit(limit) @staticmethod def csv_to_df(**kwargs): diff --git a/superset/sql_parse.py b/superset/sql_parse.py index ea1c9c38851c..7b5103924c06 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -20,18 +20,24 @@ def __init__(self, sql_statement): self.sql = sql_statement self._table_names = set() self._alias_names = set() + self._limit = None # TODO: multistatement support logging.info('Parsing with sqlparse statement {}'.format(self.sql)) self._parsed = sqlparse.parse(self.sql) for statement in self._parsed: self.__extract_from_token(statement) + self._limit = self._extract_limit_from_query(statement) self._table_names = self._table_names - self._alias_names @property def tables(self): return self._table_names + @property + def limit(self): + return self._limit + def is_select(self): return self._parsed[0].get_type() == 'SELECT' @@ -128,3 +134,41 @@ def __extract_from_token(self, token): for token in item.tokens: if self.__is_identifier(token): self.__process_identifier(token) + + def _get_limit_from_token(self, token): + if token.ttype == sqlparse.tokens.Literal.Number.Integer: + return int(token.value) + elif token.is_group: + return int(token.get_token_at_offset(1).value) + + def _extract_limit_from_query(self, statement): + limit_token = None + for pos, item in enumerate(statement.tokens): + if item.ttype in Keyword and item.value.lower() == 'limit': + limit_token = statement.tokens[pos + 2] + return self._get_limit_from_token(limit_token) + + def get_query_with_new_limit(self, new_limit): + """returns the query with the specified limit""" + """does not change the underlying query""" + if not self._limit: + return self.sql + ' LIMIT ' + str(new_limit) + limit_pos = None + tokens = self._parsed[0].tokens + # Add all items to before_str until there is a limit + for pos, item in enumerate(tokens): + if item.ttype in Keyword and item.value.lower() == 'limit': + limit_pos = pos + break + limit = tokens[limit_pos + 2] + if limit.ttype == sqlparse.tokens.Literal.Number.Integer: + tokens[limit_pos + 2].value = new_limit + elif limit.is_group: + tokens[limit_pos + 2].value = ( + '{}, {}'.format(next(limit.get_identifiers()), new_limit) + ) + + str_res = '' + for i in tokens: + str_res += str(i.value) + return str_res diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 447914ed5f84..c85e23a26c02 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -4,8 +4,6 @@ from __future__ import print_function from __future__ import unicode_literals -import textwrap - from superset.db_engine_specs import ( BaseEngineSpec, HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec, PrestoEngineSpec, @@ -143,18 +141,6 @@ def test_modify_limit_query(self): 'SELECT * FROM a LIMIT 1000', ) - def test_modify_newline_query(self): - self.sql_limit_regex( - 'SELECT * FROM a\nLIMIT 9999', - 'SELECT * FROM a LIMIT 1000', - ) - - def test_modify_lcase_limit_query(self): - self.sql_limit_regex( - 'SELECT * FROM a\tlimit 9999', - 'SELECT * FROM a LIMIT 1000', - ) - def test_limit_query_with_limit_subquery(self): self.sql_limit_regex( 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999', @@ -163,37 +149,38 @@ def test_limit_query_with_limit_subquery(self): def test_limit_with_expr(self): self.sql_limit_regex( - textwrap.dedent("""\ - SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT - 99990"""), - textwrap.dedent("""\ + """ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT 99990""", + """ SELECT 'LIMIT 777' AS a , b FROM - table LIMIT 1000"""), + table + LIMIT 1000""", ) def test_limit_expr_and_semicolon(self): self.sql_limit_regex( - textwrap.dedent("""\ + """ SELECT 'LIMIT 777' AS a , b FROM table - LIMIT 99990 ;"""), - textwrap.dedent("""\ + LIMIT 99990 ;""", + """ SELECT 'LIMIT 777' AS a , b FROM - table LIMIT 1000"""), + table + LIMIT 1000 ;""", ) def test_get_datatype(self): @@ -201,3 +188,51 @@ def test_get_datatype(self): self.assertEquals('TINY', MySQLEngineSpec.get_datatype(1)) self.assertEquals('VARCHAR', MySQLEngineSpec.get_datatype(15)) self.assertEquals('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR')) + + def test_limit_with_implicit_offset(self): + self.sql_limit_regex( + """ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT 99990, 999999""", + """ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT 99990, 1000""", + ) + + def test_limit_with_explicit_offset(self): + self.sql_limit_regex( + """ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT 99990 + OFFSET 999999""", + """ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT 1000 + OFFSET 999999""", + ) + + def test_limit_with_non_token_limit(self): + self.sql_limit_regex( + """ + SELECT + 'LIMIT 777'""", + """ + SELECT + 'LIMIT 777' LIMIT 1000""", + )