diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 00000000..6fa8b7bb
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,23 @@
+# http://editorconfig.org
+
+root = true
+
+[*]
+indent_style = space
+indent_size = 4
+end_of_line = crlf
+charset = utf-8
+insert_final_newline = true
+trim_trailing_whitespace = true
+
+[*.{py,ini,yaml,yml,rst}]
+indent_style = space
+indent_size = 4
+continuation_indent_size = 4
+trim_trailing_whitespace = true
+
+[{Makefile,*.bat}]
+indent_style = tab
+
+[*.md]
+trim_trailing_whitespace = false
diff --git a/.gitignore b/.gitignore
index 6dde1c36..438de5ff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+# PyCharm
+.idea/
+
 *.pyc
 docs/build
 dist
diff --git a/AUTHORS b/AUTHORS
index 0f34f069..9831fa11 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -33,6 +33,7 @@ Alphabetical list of contributors:
 * spigwitmer
 * Tim Graham
 * Victor Hahn
+* Victor Uriarte
 * vthriller
 * wayne.wuw
 * Yago Riveiro
diff --git a/CHANGES b/CHANGELOG
similarity index 100%
rename from CHANGES
rename to CHANGELOG
diff --git a/COPYING b/LICENSE
similarity index 100%
rename from COPYING
rename to LICENSE
diff --git a/examples/column_defs_lowlevel.py b/examples/column_defs_lowlevel.py
index 9e945d4e..e804bb26 100644
--- a/examples/column_defs_lowlevel.py
+++ b/examples/column_defs_lowlevel.py
@@ -15,7 +15,7 @@
 parsed = sqlparse.parse(SQL)[0]
 
 # extract the parenthesis which holds column definitions
-par = parsed.token_next_by_instance(0, sqlparse.sql.Parenthesis)
+par = parsed.token_next_by(i=sqlparse.sql.Parenthesis)
 
 
 def extract_definitions(token_list):
diff --git a/sqlparse/compat.py b/sqlparse/compat.py
index 6b263844..334883b1 100644
--- a/sqlparse/compat.py
+++ b/sqlparse/compat.py
@@ -14,29 +14,40 @@
 PY3 = sys.version_info[0] == 3
 
 if PY3:
+    def u(s):
+        return str(s)
+
+
+    range = range
     text_type = str
     string_types = (str,)
     from io import StringIO
 
-    def u(s):
-        return str(s)
-
 elif PY2:
+    def u(s, encoding=None):
+        encoding = encoding or 'unicode-escape'
+        try:
+            return unicode(s)
+        except UnicodeDecodeError:
+            return unicode(s, encoding)
+
+
+    range = xrange
     text_type = unicode
     string_types = (basestring,)
 
-    from StringIO import StringIO  # flake8: noqa
-
-    def u(s):
-        return unicode(s)
+    from StringIO import StringIO
 
 
 # Directly copied from six:
 def with_metaclass(meta, *bases):
     """Create a base class with a metaclass."""
+    # This requires a bit of explanation: the basic idea is to make a dummy
     # metaclass for one level of class instantiation that replaces itself with
     # the actual metaclass.
     class metaclass(meta):
         def __new__(cls, name, this_bases, d):
            return meta(name, bases, d)
+    return type.__new__(metaclass, 'temporary_class', (), {})
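For context on the six-style helper above: with_metaclass builds a throwaway intermediate metaclass so that the real metaclass is applied exactly once when the class statement runs. A minimal usage sketch (the class names here are illustrative, not from the codebase):

    from sqlparse.compat import with_metaclass

    class Meta(type):
        pass

    class Base(object):
        pass

    # MyClass is created by Meta, and its MRO is (MyClass, Base, object);
    # the temporary helper class never appears in the hierarchy.
    class MyClass(with_metaclass(Meta, Base)):
        pass

    assert type(MyClass) is Meta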
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index 982488bd..e30abab8 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -1,450 +1,268 @@
 # -*- coding: utf-8 -*-
 
-import itertools
-
 from sqlparse import sql
 from sqlparse import tokens as T
+from sqlparse.utils import recurse, imt, find_matching
+
+M_ROLE = (T.Keyword, ('null', 'role'))
+M_SEMICOLON = (T.Punctuation, ';')
+M_COMMA = (T.Punctuation, ',')
+
+T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float)
+T_STRING = (T.String, T.String.Single, T.String.Symbol)
+T_NAME = (T.Name, T.Name.Placeholder)
 
 
-def _group_left_right(tlist, ttype, value, cls,
-                      check_right=lambda t: True,
-                      check_left=lambda t: True,
-                      include_semicolon=False):
-    [_group_left_right(sgroup, ttype, value, cls, check_right, check_left,
-                       include_semicolon) for sgroup in tlist.get_sublists()
-     if not isinstance(sgroup, cls)]
-    idx = 0
-    token = tlist.token_next_match(idx, ttype, value)
+def _group_left_right(tlist, m, cls,
+                      valid_left=lambda t: t is not None,
+                      valid_right=lambda t: t is not None,
+                      semicolon=False):
+    """Groups together tokens that are joined by a middle token, i.e. x < y"""
+    [_group_left_right(sgroup, m, cls, valid_left, valid_right, semicolon)
+     for sgroup in tlist.get_sublists() if not isinstance(sgroup, cls)]
+
+    token = tlist.token_next_by(m=m)
     while token:
-        right = tlist.token_next(tlist.token_index(token))
-        left = tlist.token_prev(tlist.token_index(token))
-        if right is None or not check_right(right):
-            token = tlist.token_next_match(tlist.token_index(token) + 1,
-                                           ttype, value)
-        elif left is None or not check_left(left):
-            token = tlist.token_next_match(tlist.token_index(token) + 1,
-                                           ttype, value)
-        else:
-            if include_semicolon:
-                sright = tlist.token_next_match(tlist.token_index(right),
-                                                T.Punctuation, ';')
-                if sright is not None:
-                    # only overwrite "right" if a semicolon is actually
-                    # present.
-                    right = sright
-            tokens = tlist.tokens_between(left, right)[1:]
-            if not isinstance(left, cls):
-                new = cls([left])
-                new_idx = tlist.token_index(left)
-                tlist.tokens.remove(left)
-                tlist.tokens.insert(new_idx, new)
-                left = new
-            left.tokens.extend(tokens)
-            for t in tokens:
-                tlist.tokens.remove(t)
-            token = tlist.token_next_match(tlist.token_index(left) + 1,
-                                           ttype, value)
-
-
-def _find_matching(idx, tlist, start_ttype, start_value, end_ttype, end_value):
-    depth = 1
-    for tok in tlist.tokens[idx:]:
-        if tok.match(start_ttype, start_value):
-            depth += 1
-        elif tok.match(end_ttype, end_value):
-            depth -= 1
-            if depth == 1:
-                return tok
-    return None
-
-
-def _group_matching(tlist, start_ttype, start_value, end_ttype, end_value,
-                    cls, include_semicolon=False, recurse=False):
-
-    [_group_matching(sgroup, start_ttype, start_value, end_ttype, end_value,
-                     cls, include_semicolon) for sgroup in tlist.get_sublists()
-     if recurse]
-    if isinstance(tlist, cls):
-        idx = 1
-    else:
-        idx = 0
-    token = tlist.token_next_match(idx, start_ttype, start_value)
+        left, right = tlist.token_prev(token), tlist.token_next(token)
+
+        if valid_left(left) and valid_right(right):
+            if semicolon:
+                sright = tlist.token_next_by(m=M_SEMICOLON, idx=right)
+                right = sright or right  # only overwrite if a semicolon is present
+            tokens = tlist.tokens_between(left, right)
+            token = tlist.group_tokens(cls, tokens, extend=True)
+        token = tlist.token_next_by(m=m, idx=token)
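The rewritten _group_left_right above drives group_assignment, group_as and friends below. A small sketch of the observable behavior through the public API (it mirrors an existing assertion in test_grouping.py, so it should hold before and after this refactor):

    import sqlparse
    from sqlparse import sql

    parsed = sqlparse.parse('foo := 1;')[0]
    # the left operand, ':=', the right operand and the trailing semicolon
    # are folded into a single Assignment group
    assert isinstance(parsed.tokens[0], sql.Assignment)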
+
+
+def _group_matching(tlist, cls):
+    """Groups tokens that have a beginning and an end, i.e. parenthesis, brackets."""
+    idx = 1 if imt(tlist, i=cls) else 0
+
+    token = tlist.token_next_by(m=cls.M_OPEN, idx=idx)
     while token:
-        tidx = tlist.token_index(token)
-        end = _find_matching(tidx, tlist, start_ttype, start_value,
-                             end_ttype, end_value)
-        if end is None:
-            idx = tidx + 1
-        else:
-            if include_semicolon:
-                next_ = tlist.token_next(tlist.token_index(end))
-                if next_ and next_.match(T.Punctuation, ';'):
-                    end = next_
-            group = tlist.group_tokens(cls, tlist.tokens_between(token, end))
-            _group_matching(group, start_ttype, start_value,
-                            end_ttype, end_value, cls, include_semicolon)
-            idx = tlist.token_index(group) + 1
-        token = tlist.token_next_match(idx, start_ttype, start_value)
+        end = find_matching(tlist, token, cls.M_OPEN, cls.M_CLOSE)
+        if end is not None:
+            token = tlist.group_tokens(cls, tlist.tokens_between(token, end))
+            _group_matching(token, cls)
+        token = tlist.token_next_by(m=cls.M_OPEN, idx=token)
 
 
 def group_if(tlist):
-    _group_matching(tlist, T.Keyword, 'IF', T.Keyword, 'END IF', sql.If, True)
+    _group_matching(tlist, sql.If)
 
 
 def group_for(tlist):
-    _group_matching(tlist, T.Keyword, 'FOR', T.Keyword, 'END LOOP',
-                    sql.For, True)
+    _group_matching(tlist, sql.For)
 
 
 def group_foreach(tlist):
-    _group_matching(tlist, T.Keyword, 'FOREACH', T.Keyword, 'END LOOP',
-                    sql.For, True)
+    _group_matching(tlist, sql.For)
 
 
 def group_begin(tlist):
-    _group_matching(tlist, T.Keyword, 'BEGIN', T.Keyword, 'END',
-                    sql.Begin, True)
+    _group_matching(tlist, sql.Begin)
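The cls.M_OPEN/cls.M_CLOSE pairs consulted by _group_matching are declared on the token-list classes in sqlparse/sql.py later in this patch. A sketch of the intended effect for nested parentheses (illustrative input):

    import sqlparse
    from sqlparse import sql

    parsed = sqlparse.parse('(a, (b, c))')[0]
    outer = parsed.tokens[0]
    assert isinstance(outer, sql.Parenthesis)
    # _group_matching recurses into the group it just built,
    # so the inner pair becomes a nested Parenthesis
    inner = outer.token_next_by(i=sql.Parenthesis)
    assert isinstance(inner, sql.Parenthesis)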
 def group_as(tlist):
-
-    def _right_valid(token):
-        # Currently limited to DML/DDL. Maybe additional more non SQL reserved
-        # keywords should appear here (see issue8).
-        return token.ttype not in (T.DML, T.DDL)
-
-    def _left_valid(token):
-        if token.ttype is T.Keyword and token.value in ('NULL',):
-            return True
-        return token.ttype is not T.Keyword
-
-    _group_left_right(tlist, T.Keyword, 'AS', sql.Identifier,
-                      check_right=_right_valid,
-                      check_left=_left_valid)
+    lfunc = lambda tk: not imt(tk, t=T.Keyword) or tk.value == 'NULL'
+    rfunc = lambda tk: not imt(tk, t=(T.DML, T.DDL))
+    _group_left_right(tlist, (T.Keyword, 'AS'), sql.Identifier,
+                      valid_left=lfunc, valid_right=rfunc)
 
 
 def group_assignment(tlist):
-    _group_left_right(tlist, T.Assignment, ':=', sql.Assignment,
-                      include_semicolon=True)
+    _group_left_right(tlist, (T.Assignment, ':='), sql.Assignment,
+                      semicolon=True)
 
 
 def group_comparison(tlist):
+    I_COMPARABLE = (sql.Parenthesis, sql.Function, sql.Identifier)
+    T_COMPARABLE = T_NUMERICAL + T_STRING + T_NAME
 
-    def _parts_valid(token):
-        return (token.ttype in (T.String.Symbol, T.String.Single,
-                                T.Name, T.Number, T.Number.Float,
-                                T.Number.Integer, T.Literal,
-                                T.Literal.Number.Integer, T.Name.Placeholder)
-                or isinstance(token, (sql.Identifier, sql.Parenthesis,
-                                      sql.Function))
-                or (token.ttype is T.Keyword
-                    and token.value.upper() in ['NULL', ]))
-    _group_left_right(tlist, T.Operator.Comparison, None, sql.Comparison,
-                      check_left=_parts_valid, check_right=_parts_valid)
+    func = lambda tk: imt(tk, t=T_COMPARABLE, i=I_COMPARABLE) or (
+        imt(tk, t=T.Keyword) and tk.value.upper() == 'NULL')
+
+    _group_left_right(tlist, (T.Operator.Comparison, None), sql.Comparison,
+                      valid_left=func, valid_right=func)
 
 
 def group_case(tlist):
-    _group_matching(tlist, T.Keyword, 'CASE', T.Keyword, 'END', sql.Case,
-                    include_semicolon=True, recurse=True)
+    _group_matching(tlist, sql.Case)
+@recurse(sql.Identifier)
 def group_identifier(tlist):
-    def _consume_cycle(tl, i):
-        # TODO: Usage of Wildcard token is ambivalent here.
-        x = itertools.cycle((
-            lambda y: (y.match(T.Punctuation, '.')
-                       or y.ttype in (T.Operator,
-                                      T.Wildcard,
-                                      T.Name)
-                       or isinstance(y, sql.SquareBrackets)),
-            lambda y: (y.ttype in (T.String.Symbol,
-                                   T.Name,
-                                   T.Wildcard,
-                                   T.Literal.String.Single,
-                                   T.Literal.Number.Integer,
-                                   T.Literal.Number.Float)
-                       or isinstance(y, (sql.Parenthesis,
-                                         sql.SquareBrackets,
-                                         sql.Function)))))
-        for t in tl.tokens[i:]:
-            # Don't take whitespaces into account.
-            if t.ttype is T.Whitespace:
-                yield t
-                continue
-            if next(x)(t):
-                yield t
-            else:
-                if isinstance(t, sql.Comment) and t.is_multiline():
-                    yield t
-                if t.ttype is T.Keyword.Order:
-                    yield t
-                return
-
-    def _next_token(tl, i):
-        # chooses the next token. if two tokens are found then the
-        # first is returned.
-        t1 = tl.token_next_by_type(
-            i, (T.String.Symbol, T.Name, T.Literal.Number.Integer,
-                T.Literal.Number.Float))
-
-        i1 = tl.token_index(t1, start=i) if t1 else None
-        t2_end = None if i1 is None else i1 + 1
-        t2 = tl.token_next_by_instance(i, (sql.Function, sql.Parenthesis),
-                                       end=t2_end)
-
-        if t1 and t2:
-            i2 = tl.token_index(t2, start=i)
-            if i1 > i2:
-                return t2
-            else:
-                return t1
-        elif t1:
-            return t1
-        else:
-            return t2
+    T_IDENT = (T.String.Symbol, T.Name)
 
-    # bottom up approach: group subgroups first
-    [group_identifier(sgroup) for sgroup in tlist.get_sublists()
-     if not isinstance(sgroup, sql.Identifier)]
-
-    # real processing
-    idx = 0
-    token = _next_token(tlist, idx)
+    token = tlist.token_next_by(t=T_IDENT)
     while token:
-        identifier_tokens = [token] + list(
-            _consume_cycle(tlist,
-                           tlist.token_index(token, start=idx) + 1))
-        # remove trailing whitespace
-        if identifier_tokens and identifier_tokens[-1].ttype is T.Whitespace:
-            identifier_tokens = identifier_tokens[:-1]
-        if not (len(identifier_tokens) == 1
-                and (isinstance(identifier_tokens[0], (sql.Function,
-                                                       sql.Parenthesis))
-                     or identifier_tokens[0].ttype in (
-                         T.Literal.Number.Integer, T.Literal.Number.Float))):
-            group = tlist.group_tokens(sql.Identifier, identifier_tokens)
-            idx = tlist.token_index(group, start=idx) + 1
-        else:
-            idx += 1
-        token = _next_token(tlist, idx)
+        token = tlist.group_tokens(sql.Identifier, [token, ])
+        token = tlist.token_next_by(t=T_IDENT, idx=token)
 
 
-def group_identifier_list(tlist):
-    [group_identifier_list(sgroup) for sgroup in tlist.get_sublists()
-     if not isinstance(sgroup, sql.IdentifierList)]
-    # Allowed list items
-    fend1_funcs = [lambda t: isinstance(t, (sql.Identifier, sql.Function,
-                                            sql.Case)),
-                   lambda t: t.is_whitespace(),
-                   lambda t: t.ttype == T.Name,
-                   lambda t: t.ttype == T.Wildcard,
-                   lambda t: t.match(T.Keyword, 'null'),
-                   lambda t: t.match(T.Keyword, 'role'),
-                   lambda t: t.ttype == T.Number.Integer,
-                   lambda t: t.ttype == T.String.Single,
-                   lambda t: t.ttype == T.Name.Placeholder,
-                   lambda t: t.ttype == T.Keyword,
-                   lambda t: isinstance(t, sql.Comparison),
-                   lambda t: isinstance(t, sql.Comment),
-                   lambda t: t.ttype == T.Comment.Multiline,
-                   ]
-    tcomma = tlist.token_next_match(0, T.Punctuation, ',')
-    start = None
-    while tcomma is not None:
-        # Go back one idx to make sure to find the correct tcomma
-        idx = tlist.token_index(tcomma)
-        before = tlist.token_prev(idx)
-        after = tlist.token_next(idx)
-        # Check if the tokens around tcomma belong to a list
-        bpassed = apassed = False
-        for func in fend1_funcs:
-            if before is not None and func(before):
-                bpassed = True
-            if after is not None and func(after):
-                apassed = True
-        if not bpassed or not apassed:
-            # Something's wrong here, skip ahead to next ","
-            start = None
-            tcomma = tlist.token_next_match(idx + 1,
-                                            T.Punctuation, ',')
-        else:
-            if start is None:
-                start = before
-            after_idx = tlist.token_index(after, start=idx)
-            next_ = tlist.token_next(after_idx)
-            if next_ is None or not next_.match(T.Punctuation, ','):
-                # Reached the end of the list
-                tokens = tlist.tokens_between(start, after)
-                group = tlist.group_tokens(sql.IdentifierList, tokens)
-                start = None
-                tcomma = tlist.token_next_match(tlist.token_index(group) + 1,
-                                                T.Punctuation, ',')
-            else:
-                tcomma = next_
+def group_period(tlist):
+    lfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Identifier),
+                           t=(T.Name, T.String.Symbol,))
+    rfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Function),
+                           t=(T.Name, T.String.Symbol, T.Wildcard))
+    _group_left_right(tlist, (T.Punctuation, '.'), sql.Identifier,
+                      valid_left=lfunc, valid_right=rfunc)
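group_period stitches dotted names back together with the same left/right machinery. A sketch of the resulting tree on an illustrative query:

    import sqlparse
    from sqlparse import sql

    stmt = sqlparse.parse('select t.col from tbl t')[0]
    ident = stmt.token_next_by(i=sql.Identifier)
    assert ident.get_parent_name() == 't'
    assert ident.get_real_name() == 'col'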
-def group_brackets(tlist):
-    """Group parentheses () or square brackets []
-
-    This is just like _group_matching, but complicated by the fact that
-    round brackets can contain square bracket groups and vice versa
-    """
-    if isinstance(tlist, (sql.Parenthesis, sql.SquareBrackets)):
-        idx = 1
-    else:
-        idx = 0
 
+def group_arrays(tlist):
+    token = tlist.token_next_by(i=sql.SquareBrackets)
+    while token:
+        prev = tlist.token_prev(idx=token)
+        if imt(prev, i=(sql.SquareBrackets, sql.Identifier, sql.Function),
+               t=(T.Name, T.String.Symbol,)):
+            tokens = tlist.tokens_between(prev, token)
+            token = tlist.group_tokens(sql.Identifier, tokens, extend=True)
+        token = tlist.token_next_by(i=sql.SquareBrackets, idx=token)
+
+
+@recurse(sql.Identifier)
+def group_operator(tlist):
+    I_CYCLE = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
+               sql.Identifier,)  # sql.Operation)
+    # wildcards wouldn't have operations next to them
+    T_CYCLE = T_NUMERICAL + T_STRING + T_NAME  # + T.Wildcard
+    func = lambda tk: imt(tk, i=I_CYCLE, t=T_CYCLE)
+
+    token = tlist.token_next_by(t=(T.Operator, T.Wildcard))
+    while token:
+        left, right = tlist.token_prev(token), tlist.token_next(token)
+
+        if func(left) and func(right):
+            token.ttype = T.Operator
+            tokens = tlist.tokens_between(left, right)
+            # token = tlist.group_tokens(sql.Operation, tokens)
+            token = tlist.group_tokens(sql.Identifier, tokens)
-    # Find the first opening bracket
-    token = tlist.token_next_match(idx, T.Punctuation, ['(', '['])
+        token = tlist.token_next_by(t=(T.Operator, T.Wildcard), idx=token)
+
+
+@recurse(sql.IdentifierList)
+def group_identifier_list(tlist):
+    I_IDENT_LIST = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
+                    sql.IdentifierList)  # sql.Operation
+    T_IDENT_LIST = (T_NUMERICAL + T_STRING + T_NAME +
+                    (T.Keyword, T.Comment, T.Wildcard))
+
+    func = lambda t: imt(t, i=I_IDENT_LIST, m=M_ROLE, t=T_IDENT_LIST)
+    token = tlist.token_next_by(m=M_COMMA)
     while token:
-        start_val = token.value  # either '(' or '['
-        if start_val == '(':
-            end_val = ')'
-            group_class = sql.Parenthesis
-        else:
-            end_val = ']'
-            group_class = sql.SquareBrackets
+        before, after = tlist.token_prev(token), tlist.token_next(token)
-        tidx = tlist.token_index(token)
+        if func(before) and func(after):
+            tokens = tlist.tokens_between(before, after)
+            token = tlist.group_tokens(sql.IdentifierList, tokens, extend=True)
+        token = tlist.token_next_by(m=M_COMMA, idx=token)
-        # Find the corresponding closing bracket
-        end = _find_matching(tidx, tlist, T.Punctuation, start_val,
-                             T.Punctuation, end_val)
-        if end is None:
-            idx = tidx + 1
-        else:
-            group = tlist.group_tokens(group_class,
-                                       tlist.tokens_between(token, end))
+
+
+def group_brackets(tlist):
+    _group_matching(tlist, sql.SquareBrackets)
-            # Check for nested bracket groups within this group
-            group_brackets(group)
-            idx = tlist.token_index(group) + 1
-        # Find the next opening bracket
-        token = tlist.token_next_match(idx, T.Punctuation, ['(', '['])
+
+
+def group_parenthesis(tlist):
+    _group_matching(tlist, sql.Parenthesis)
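group_operator is new in this patch; note that the sql.Operation grouping is deliberately left commented out for now, so arithmetic expressions land in a plain Identifier group. A sketch of the expected behavior (illustrative query):

    import sqlparse
    from sqlparse import sql

    stmt = sqlparse.parse('select a + b from t')[0]
    expr = stmt.token_next_by(i=sql.Identifier)
    # 'a', '+' and 'b' were merged into one group by group_operator
    assert expr.value == 'a + b'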
 
 
+@recurse(sql.Comment)
 def group_comments(tlist):
-    [group_comments(sgroup) for sgroup in tlist.get_sublists()
-     if not isinstance(sgroup, sql.Comment)]
-    idx = 0
-    token = tlist.token_next_by_type(idx, T.Comment)
+    token = tlist.token_next_by(t=T.Comment)
     while token:
-        tidx = tlist.token_index(token)
-        end = tlist.token_not_matching(tidx + 1,
-                                       [lambda t: t.ttype in T.Comment,
-                                        lambda t: t.is_whitespace()])
-        if end is None:
-            idx = tidx + 1
-        else:
-            eidx = tlist.token_index(end)
-            grp_tokens = tlist.tokens_between(token,
-                                              tlist.token_prev(eidx, False))
-            group = tlist.group_tokens(sql.Comment, grp_tokens)
-            idx = tlist.token_index(group)
-        token = tlist.token_next_by_type(idx, T.Comment)
+        end = tlist.token_not_matching(
+            token, lambda tk: imt(tk, t=T.Comment) or tk.is_whitespace())
+        if end is not None:
+            end = tlist.token_prev(end, False)
+            tokens = tlist.tokens_between(token, end)
+            token = tlist.group_tokens(sql.Comment, tokens)
+
+        token = tlist.token_next_by(t=T.Comment, idx=token)
 
 
+@recurse(sql.Where)
 def group_where(tlist):
-    [group_where(sgroup) for sgroup in tlist.get_sublists()
-     if not isinstance(sgroup, sql.Where)]
-    idx = 0
-    token = tlist.token_next_match(idx, T.Keyword, 'WHERE')
-    stopwords = ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT', 'HAVING')
+    token = tlist.token_next_by(m=sql.Where.M_OPEN)
     while token:
-        tidx = tlist.token_index(token)
-        end = tlist.token_next_match(tidx + 1, T.Keyword, stopwords)
+        end = tlist.token_next_by(m=sql.Where.M_CLOSE, idx=token)
+
         if end is None:
-            end = tlist._groupable_tokens[-1]
+            tokens = tlist.tokens_between(token, tlist._groupable_tokens[-1])
         else:
-            end = tlist.tokens[tlist.token_index(end) - 1]
-        group = tlist.group_tokens(sql.Where,
-                                   tlist.tokens_between(token, end),
-                                   ignore_ws=True)
-        idx = tlist.token_index(group)
-        token = tlist.token_next_match(idx, T.Keyword, 'WHERE')
+            tokens = tlist.tokens_between(
+                token, tlist.tokens[tlist.token_index(end) - 1])
+
+        token = tlist.group_tokens(sql.Where, tokens)
+        token = tlist.token_next_by(m=sql.Where.M_OPEN, idx=token)
 
 
+@recurse()
 def group_aliased(tlist):
-    clss = (sql.Identifier, sql.Function, sql.Case)
-    [group_aliased(sgroup) for sgroup in tlist.get_sublists()
-     if not isinstance(sgroup, clss)]
-    idx = 0
-    token = tlist.token_next_by_instance(idx, clss)
+    I_ALIAS = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
+               )  # sql.Operation)
+
+    token = tlist.token_next_by(i=I_ALIAS, t=T.Number)
     while token:
-        next_ = tlist.token_next(tlist.token_index(token))
-        if next_ is not None and isinstance(next_, clss):
-            if not next_.value.upper().startswith('VARCHAR'):
-                grp = tlist.tokens_between(token, next_)[1:]
-                token.tokens.extend(grp)
-                for t in grp:
-                    tlist.tokens.remove(t)
-        idx = tlist.token_index(token) + 1
-        token = tlist.token_next_by_instance(idx, clss)
+        next_ = tlist.token_next(token)
+        if imt(next_, i=sql.Identifier):
+            tokens = tlist.tokens_between(token, next_)
+            token = tlist.group_tokens(sql.Identifier, tokens, extend=True)
+        token = tlist.token_next_by(i=I_ALIAS, t=T.Number, idx=token)
 
 
 def group_typecasts(tlist):
-    _group_left_right(tlist, T.Punctuation, '::', sql.Identifier)
+    _group_left_right(tlist, (T.Punctuation, '::'), sql.Identifier)
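group_typecasts reuses _group_left_right with the new match-tuple calling convention; the observable behavior should be unchanged by the refactor. Sketch:

    import sqlparse
    from sqlparse import sql

    stmt = sqlparse.parse('x::int')[0]
    ident = stmt.tokens[0]
    assert isinstance(ident, sql.Identifier)
    assert ident.get_typecast() == 'int'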
 
 
+@recurse(sql.Function)
 def group_functions(tlist):
-    [group_functions(sgroup) for sgroup in tlist.get_sublists()
-     if not isinstance(sgroup, sql.Function)]
-    idx = 0
-    token = tlist.token_next_by_type(idx, T.Name)
+    token = tlist.token_next_by(t=T.Name)
     while token:
         next_ = tlist.token_next(token)
-        if not isinstance(next_, sql.Parenthesis):
-            idx = tlist.token_index(token) + 1
-        else:
-            func = tlist.group_tokens(sql.Function,
-                                      tlist.tokens_between(token, next_))
-            idx = tlist.token_index(func) + 1
-        token = tlist.token_next_by_type(idx, T.Name)
+        if imt(next_, i=sql.Parenthesis):
+            tokens = tlist.tokens_between(token, next_)
+            token = tlist.group_tokens(sql.Function, tokens)
+        token = tlist.token_next_by(t=T.Name, idx=token)
 
 
 def group_order(tlist):
-    idx = 0
-    token = tlist.token_next_by_type(idx, T.Keyword.Order)
+    """Groups together an Identifier and a following ASC/DESC token"""
+    token = tlist.token_next_by(t=T.Keyword.Order)
     while token:
         prev = tlist.token_prev(token)
-        if isinstance(prev, sql.Identifier):
-            ido = tlist.group_tokens(sql.Identifier,
-                                     tlist.tokens_between(prev, token))
-            idx = tlist.token_index(ido) + 1
-        else:
-            idx = tlist.token_index(token) + 1
-        token = tlist.token_next_by_type(idx, T.Keyword.Order)
+        if imt(prev, i=sql.Identifier, t=T.Number):
+            tokens = tlist.tokens_between(prev, token)
+            token = tlist.group_tokens(sql.Identifier, tokens)
+        token = tlist.token_next_by(t=T.Keyword.Order, idx=token)
 
 
+@recurse()
 def align_comments(tlist):
-    [align_comments(sgroup) for sgroup in tlist.get_sublists()]
-    idx = 0
-    token = tlist.token_next_by_instance(idx, sql.Comment)
+    token = tlist.token_next_by(i=sql.Comment)
     while token:
-        before = tlist.token_prev(tlist.token_index(token))
+        before = tlist.token_prev(token)
         if isinstance(before, sql.TokenList):
-            grp = tlist.tokens_between(before, token)[1:]
-            before.tokens.extend(grp)
-            for t in grp:
-                tlist.tokens.remove(t)
-            idx = tlist.token_index(before) + 1
-        else:
-            idx = tlist.token_index(token) + 1
-        token = tlist.token_next_by_instance(idx, sql.Comment)
+            tokens = tlist.tokens_between(before, token)
+            token = tlist.group_tokens(sql.TokenList, tokens, extend=True)
+        token = tlist.token_next_by(i=sql.Comment, idx=token)
 
 
 def group(tlist):
     for func in [
         group_comments,
         group_brackets,
+        group_parenthesis,
         group_functions,
         group_where,
         group_case,
+        group_period,
+        group_arrays,
         group_identifier,
+        group_operator,
         group_order,
         group_typecasts,
         group_as,
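The group() entry point above still applies the sub-groupers in a fixed order, and that order matters: parentheses and functions are built before identifiers, and group_period runs before group_identifier so dotted names are not split apart. The whole pipeline runs through the usual public entry point; _pprint_tree is a handy debugging aid for inspecting the result:

    import sqlparse

    stmt = sqlparse.parse('select f(x), t.col from t where a = 1')[0]
    stmt._pprint_tree()  # dumps the nested token structure to stdout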
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 68e9b1ab..72f17d0b 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -200,9 +200,7 @@ class StripCommentsFilter:
 
     def _get_next_comment(self, tlist):
         # TODO(andi) Comment types should be unified, see related issue38
-        token = tlist.token_next_by_instance(0, sql.Comment)
-        if token is None:
-            token = tlist.token_next_by_type(0, T.Comment)
+        token = tlist.token_next_by(i=sql.Comment, t=T.Comment)
         return token
 
     def _process(self, tlist):
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index f357572c..9afdac37 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -7,6 +7,7 @@
 
 from sqlparse import tokens as T
 from sqlparse.compat import string_types, u
+from sqlparse.utils import imt, remove_quotes
 
 
 class Token(object):
@@ -77,7 +78,7 @@ def match(self, ttype, values, regex=False):
 
         if regex:
             if isinstance(values, string_types):
-                values = set([values])
+                values = {values}
 
             if self.ttype is T.Keyword:
                 values = set(re.compile(v, re.IGNORECASE) for v in values)
@@ -150,7 +151,7 @@ def __init__(self, tokens=None):
         if tokens is None:
             tokens = []
         self.tokens = tokens
-        Token.__init__(self, None, self._to_string())
+        super(TokenList, self).__init__(None, self.__str__())
 
     def __unicode__(self):
         return self._to_string()
@@ -184,14 +185,6 @@ def _pprint_tree(self, max_depth=None, depth=0):
             if (token.is_group() and (max_depth is None or depth < max_depth)):
                 token._pprint_tree(max_depth, depth + 1)
 
-    def _remove_quotes(self, val):
-        """Helper that removes surrounding quotes from strings."""
-        if not val:
-            return val
-        if val[0] in ('"', '\'') and val[-1] == val[0]:
-            val = val[1:-1]
-        return val
-
     def get_token_at_offset(self, offset):
         """Returns the token that is on position offset."""
         idx = 0
@@ -213,12 +206,12 @@ def flatten(self):
         else:
             yield token
 
-#    def __iter__(self):
-#        return self
-#
-#    def next(self):
-#        for token in self.tokens:
-#            yield token
+    # def __iter__(self):
+    #     return self
+    #
+    # def next(self):
+    #     for token in self.tokens:
+    #         yield token
 
     def is_group(self):
         return True
@@ -232,6 +225,27 @@ def get_sublists(self):
     @property
     def _groupable_tokens(self):
         return self.tokens
 
+    def _token_matching(self, funcs, start=0, end=None, reverse=False):
+        """Returns the next token that matches any of *funcs*."""
+        if start is None:
+            return None
+
+        if not isinstance(start, int):
+            start = self.token_index(start) + 1
+
+        if not isinstance(funcs, (list, tuple)):
+            funcs = (funcs,)
+
+        if reverse:
+            iterable = iter(reversed(self.tokens[end:start - 1]))
+        else:
+            iterable = self.tokens[start:end]
+
+        for token in iterable:
+            for func in funcs:
+                if func(token):
+                    return token
+
     def token_first(self, ignore_whitespace=True, ignore_comments=False):
         """Returns the first child token.
 
@@ -241,12 +255,13 @@ def token_first(self, ignore_whitespace=True, ignore_comments=False):
         if *ignore_comments* is ``True`` (default: ``False``), comments are
         ignored too.
         """
-        for token in self.tokens:
-            if ignore_whitespace and token.is_whitespace():
-                continue
-            if ignore_comments and isinstance(token, Comment):
-                continue
-            return token
+        funcs = lambda tk: not ((ignore_whitespace and tk.is_whitespace()) or
+                                (ignore_comments and imt(tk, i=Comment)))
+        return self._token_matching(funcs)
+
+    def token_next_by(self, i=None, m=None, t=None, idx=0, end=None):
+        funcs = lambda tk: imt(tk, i, m, t)
+        return self._token_matching(funcs, idx, end)
 
     def token_next_by_instance(self, idx, clss, end=None):
         """Returns the next token matching a class.
@@ -256,48 +271,26 @@
 
         If no matching token can be found ``None`` is returned.
         """
-        if not isinstance(clss, (list, tuple)):
-            clss = (clss,)
-
-        for token in self.tokens[idx:end]:
-            if isinstance(token, clss):
-                return token
+        funcs = lambda tk: imt(tk, i=clss)
+        return self._token_matching(funcs, idx, end)
 
     def token_next_by_type(self, idx, ttypes):
        """Returns next matching token by it's token type."""
-        if not isinstance(ttypes, (list, tuple)):
-            ttypes = [ttypes]
-
-        for token in self.tokens[idx:]:
-            if token.ttype in ttypes:
-                return token
+        funcs = lambda tk: imt(tk, t=ttypes)
+        return self._token_matching(funcs, idx)
 
     def token_next_match(self, idx, ttype, value, regex=False):
         """Returns next token where it's ``match`` method returns ``True``."""
-        if not isinstance(idx, int):
-            idx = self.token_index(idx)
-
-        for n in range(idx, len(self.tokens)):
-            token = self.tokens[n]
-            if token.match(ttype, value, regex):
-                return token
+        funcs = lambda tk: imt(tk, m=(ttype, value, regex))
+        return self._token_matching(funcs, idx)
 
     def token_not_matching(self, idx, funcs):
-        for token in self.tokens[idx:]:
-            passed = False
-            for func in funcs:
-                if func(token):
-                    passed = True
-                    break
-
-            if not passed:
-                return token
+        funcs = (funcs,) if not isinstance(funcs, (list, tuple)) else funcs
+        funcs = [lambda tk: not func(tk) for func in funcs]
+        return self._token_matching(funcs, idx)
 
     def token_matching(self, idx, funcs):
-        for token in self.tokens[idx:]:
-            for func in funcs:
-                if func(token):
-                    return token
+        return self._token_matching(funcs, idx)
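All of the lookup methods above now funnel into _token_matching, and token_next_by exposes the imt()-style keywords directly; the legacy token_next_by_instance/_type/_match remain as thin wrappers. A usage sketch:

    import sqlparse
    from sqlparse import sql, tokens as T

    stmt = sqlparse.parse('select * from foo where bar = 1')[0]

    where = stmt.token_next_by(i=sql.Where)          # by instance
    dml = stmt.token_next_by(t=T.DML)                # by token type
    star = stmt.token_next_by(m=(T.Wildcard, '*'))   # by ttype/value match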
 
     def token_prev(self, idx, skip_ws=True):
         """Returns the previous token relative to *idx*.
 
@@ -305,17 +298,10 @@ def token_prev(self, idx, skip_ws=True):
         If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
         ``None`` is returned if there's no previous token.
         """
-        if idx is None:
-            return None
-
-        if not isinstance(idx, int):
-            idx = self.token_index(idx)
-
-        while idx:
-            idx -= 1
-            if self.tokens[idx].is_whitespace() and skip_ws:
-                continue
-            return self.tokens[idx]
+        if isinstance(idx, int):
+            idx += 1  # a lot of existing code pre-compensates for this
+        funcs = lambda tk: not (tk.is_whitespace() and skip_ws)
+        return self._token_matching(funcs, idx, reverse=True)
 
     def token_next(self, idx, skip_ws=True):
         """Returns the next token relative to *idx*.
 
@@ -323,59 +309,56 @@ def token_next(self, idx, skip_ws=True):
         If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
         ``None`` is returned if there's no next token.
         """
-        if idx is None:
-            return None
-
-        if not isinstance(idx, int):
-            idx = self.token_index(idx)
-
-        while idx < len(self.tokens) - 1:
-            idx += 1
-            if self.tokens[idx].is_whitespace() and skip_ws:
-                continue
-            return self.tokens[idx]
+        if isinstance(idx, int):
+            idx += 1  # a lot of existing code pre-compensates for this
+        funcs = lambda tk: not (tk.is_whitespace() and skip_ws)
+        return self._token_matching(funcs, idx)
 
     def token_index(self, token, start=0):
         """Return list index of token."""
-        if start > 0:
-            # Performing `index` manually is much faster when starting
-            # in the middle of the list of tokens and expecting to find
-            # the token near to the starting index.
-            for i in range(start, len(self.tokens)):
-                if self.tokens[i] == token:
-                    return i
-            return -1
-        return self.tokens.index(token)
-
-    def tokens_between(self, start, end, exclude_end=False):
+        start = self.token_index(start) if not isinstance(start, int) else start
+        return start + self.tokens[start:].index(token)
+
+    def tokens_between(self, start, end, include_end=True):
         """Return all tokens between (and including) start and end.
 
-        If *exclude_end* is ``True`` (default is ``False``) the end token
-        is included too.
+        If *include_end* is ``False`` (default is ``True``) the end token
+        is excluded.
         """
-        # FIXME(andi): rename exclude_end to inlcude_end
-        if exclude_end:
-            offset = 0
-        else:
-            offset = 1
-        end_idx = self.token_index(end) + offset
         start_idx = self.token_index(start)
+        end_idx = include_end + self.token_index(end)
         return self.tokens[start_idx:end_idx]
""" - # FIXME(andi): rename exclude_end to inlcude_end - if exclude_end: - offset = 0 - else: - offset = 1 - end_idx = self.token_index(end) + offset start_idx = self.token_index(start) + end_idx = include_end + self.token_index(end) return self.tokens[start_idx:end_idx] - def group_tokens(self, grp_cls, tokens, ignore_ws=False): + def group_tokens(self, grp_cls, tokens, ignore_ws=False, extend=False): """Replace tokens by an instance of *grp_cls*.""" - idx = self.token_index(tokens[0]) if ignore_ws: while tokens and tokens[-1].is_whitespace(): tokens = tokens[:-1] - for t in tokens: - self.tokens.remove(t) - grp = grp_cls(tokens) + + left = tokens[0] + idx = self.token_index(left) + + if extend: + if not isinstance(left, grp_cls): + grp = grp_cls([left]) + self.tokens.remove(left) + self.tokens.insert(idx, grp) + left = grp + left.parent = self + tokens = tokens[1:] + left.tokens.extend(tokens) + left.value = left.__str__() + + else: + left = grp_cls(tokens) + left.parent = self + self.tokens.insert(idx, left) + for token in tokens: - token.parent = grp - grp.parent = self - self.tokens.insert(idx, grp) - return grp + token.parent = left + self.tokens.remove(token) + + return left def insert_before(self, where, token): """Inserts *token* before *where*.""" @@ -397,13 +380,12 @@ def get_alias(self): """Returns the alias for this identifier or ``None``.""" # "name AS alias" - kw = self.token_next_match(0, T.Keyword, 'AS') + kw = self.token_next_by(m=(T.Keyword, 'AS')) if kw is not None: return self._get_first_name(kw, keywords=True) # "name alias" or "complicated column expression alias" - if len(self.tokens) > 2 \ - and self.token_next_by_type(0, T.Whitespace) is not None: + if len(self.tokens) > 2 and self.token_next_by(t=T.Whitespace): return self._get_first_name(reverse=True) return None @@ -440,7 +422,7 @@ def get_parent_name(self): prev_ = self.token_prev(self.token_index(dot)) if prev_ is None: # something must be verry wrong here.. return None - return self._remove_quotes(prev_.value) + return remove_quotes(prev_.value) def _get_first_name(self, idx=None, reverse=False, keywords=False): """Returns the name of the first token with a name""" @@ -457,7 +439,7 @@ def _get_first_name(self, idx=None, reverse=False, keywords=False): for tok in tokens: if tok.ttype in types: - return self._remove_quotes(tok.value) + return remove_quotes(tok.value) elif isinstance(tok, Identifier) or isinstance(tok, Function): return tok.get_name() return None @@ -510,8 +492,6 @@ class Identifier(TokenList): Identifiers may have aliases or typecasts. """ - __slots__ = ('value', 'ttype', 'tokens') - def is_wildcard(self): """Return ``True`` if this identifier contains a wildcard.""" token = self.token_next_by_type(0, T.Wildcard) @@ -546,8 +526,6 @@ def get_array_indices(self): class IdentifierList(TokenList): """A list of :class:`~sqlparse.sql.Identifier`\'s.""" - __slots__ = ('value', 'ttype', 'tokens') - def get_identifiers(self): """Returns the identifiers. 
@@ -560,7 +538,8 @@ def get_identifiers(self):
 
 class Parenthesis(TokenList):
     """Tokens between parenthesis."""
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Punctuation, '(')
+    M_CLOSE = (T.Punctuation, ')')
 
     @property
     def _groupable_tokens(self):
@@ -569,8 +548,8 @@ def _groupable_tokens(self):
 
 class SquareBrackets(TokenList):
     """Tokens between square brackets"""
-
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Punctuation, '[')
+    M_CLOSE = (T.Punctuation, ']')
 
     @property
     def _groupable_tokens(self):
@@ -579,22 +558,22 @@ def _groupable_tokens(self):
 
 class Assignment(TokenList):
     """An assignment like 'var := val;'"""
-    __slots__ = ('value', 'ttype', 'tokens')
 
 
 class If(TokenList):
     """An 'if' clause with possible 'else if' or 'else' parts."""
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Keyword, 'IF')
+    M_CLOSE = (T.Keyword, 'END IF')
 
 
 class For(TokenList):
     """A 'FOR' loop."""
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Keyword, ('FOR', 'FOREACH'))
+    M_CLOSE = (T.Keyword, 'END LOOP')
 
 
 class Comparison(TokenList):
     """A comparison used for example in WHERE clauses."""
-    __slots__ = ('value', 'ttype', 'tokens')
 
     @property
     def left(self):
@@ -607,7 +586,6 @@ def right(self):
 
 class Comment(TokenList):
     """A comment."""
-    __slots__ = ('value', 'ttype', 'tokens')
 
     def is_multiline(self):
         return self.tokens and self.tokens[0].ttype == T.Comment.Multiline
@@ -615,13 +593,15 @@ def is_multiline(self):
 
 class Where(TokenList):
     """A WHERE clause."""
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Keyword, 'WHERE')
+    M_CLOSE = (T.Keyword,
+               ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT', 'HAVING'))
 
 
 class Case(TokenList):
     """A CASE statement with one or more WHEN and possibly an ELSE part."""
-
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Keyword, 'CASE')
+    M_CLOSE = (T.Keyword, 'END')
 
     def get_cases(self):
         """Returns a list of 2-tuples (condition, value).
@@ -671,22 +651,18 @@ def get_cases(self):
 
 class Function(TokenList):
     """A function or procedure call."""
-    __slots__ = ('value', 'ttype', 'tokens')
-
     def get_parameters(self):
         """Return a list of parameters."""
         parenthesis = self.tokens[-1]
         for t in parenthesis.tokens:
-            if isinstance(t, IdentifierList):
+            if imt(t, i=IdentifierList):
                 return t.get_identifiers()
-            elif (isinstance(t, Identifier) or
-                  isinstance(t, Function) or
-                  t.ttype in T.Literal):
+            elif imt(t, i=(Function, Identifier), t=T.Literal):
                 return [t, ]
         return []
 
 
 class Begin(TokenList):
     """A BEGIN/END block."""
-
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Keyword, 'BEGIN')
+    M_CLOSE = (T.Keyword, 'END')
diff --git a/sqlparse/utils.py b/sqlparse/utils.py
index 7db9a960..90acb5cf 100644
--- a/sqlparse/utils.py
+++ b/sqlparse/utils.py
@@ -1,16 +1,13 @@
-'''
-Created on 17/05/2012
-
-@author: piranna
-'''
-
+import itertools
 import re
-from collections import OrderedDict
+from collections import OrderedDict, deque
+from contextlib import contextmanager
 
 
 class Cache(OrderedDict):
     """Cache with LRU algorithm using an OrderedDict as basis
     """
+
     def __init__(self, maxsize=100):
         OrderedDict.__init__(self)
 
@@ -113,3 +110,85 @@ def split_unquoted_newlines(text):
         else:
             outputlines[-1] += line
     return outputlines
+
+
+def remove_quotes(val):
+    """Helper that removes surrounding quotes from strings."""
+    if val is None:
+        return
+    if val[0] in ('"', "'") and val[0] == val[-1]:
+        val = val[1:-1]
+    return val
+
+
+def recurse(*cls):
+    """Function decorator to help with recursion
+
+    :param cls: Classes to not recurse over
+    :return: function
+    """
+    def wrap(f):
+        def wrapped_f(tlist):
+            for sgroup in tlist.get_sublists():
+                if not isinstance(sgroup, cls):
+                    wrapped_f(sgroup)
+            f(tlist)
+
+        return wrapped_f
+
+    return wrap
+
+
+def imt(token, i=None, m=None, t=None):
+    """Aid function to refactor comparisons for Instance, Match and TokenType
+
+    :param token:
+    :param i: Class or Tuple/List of Classes
+    :param m: Tuple of TokenType & Value. Can be list of Tuple for multiple
+    :param t: TokenType or Tuple/List of TokenTypes
+    :return: bool
+    """
+    t = (t,) if t and not isinstance(t, (list, tuple)) else t
+    m = (m,) if m and not isinstance(m, (list,)) else m
+
+    if token is None:
+        return False
+    elif i is not None and isinstance(token, i):
+        return True
+    elif m is not None and any((token.match(*x) for x in m)):
+        return True
+    elif t is not None and token.ttype in t:
+        return True
+    else:
+        return False
+
+
+def find_matching(tlist, token, M1, M2):
+    idx = tlist.token_index(token)
+    depth = 0
+    for token in tlist.tokens[idx:]:
+        if token.match(*M1):
+            depth += 1
+        elif token.match(*M2):
+            depth -= 1
+            if depth == 0:
+                return token
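imt() is the predicate that nearly every grouping function above now builds on. A quick sketch of the three keyword forms (all return a bool; a None token always fails):

    import sqlparse
    from sqlparse import sql, tokens as T
    from sqlparse.utils import imt

    first = sqlparse.parse('select * from foo;')[0].token_first()

    assert imt(first, t=T.DML)               # ttype-based test
    assert imt(first, m=(T.DML, 'select'))   # Token.match()-based test
    assert not imt(first, i=sql.Identifier)  # isinstance-based test
    assert not imt(None, t=T.DML)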
+
+
+def consume(iterator, n):
+    """Advance the iterator n-steps ahead. If n is None, consume entirely."""
+    deque(itertools.islice(iterator, n), maxlen=0)
+
+
+@contextmanager
+def offset(filter_, n=0):
+    filter_.offset += n
+    yield
+    filter_.offset -= n
+
+
+@contextmanager
+def indent(filter_, n=1):
+    filter_.indent += n
+    yield
+    filter_.indent -= n
diff --git a/tests/test_functions.py b/tests/test_functions.py
index 425ab7fa..92078150 100644
--- a/tests/test_functions.py
+++ b/tests/test_functions.py
@@ -13,6 +13,7 @@
 
 from sqlparse.filters import compact
 from sqlparse.functions import getcolumns, getlimit, IsType
+from tests.utils import FILES_DIR
 
 
 class Test_IncludeStatement(TestCase):
@@ -27,7 +28,7 @@ class Test_IncludeStatement(TestCase):
     def test_includeStatement(self):
         stream = tokenize(self.sql)
 
-        includeStatement = IncludeStatement('tests/files',
+        includeStatement = IncludeStatement(FILES_DIR,
                                             raiseexceptions=True)
         stream = includeStatement.process(None, stream)
         stream = compact(stream)
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index 7dc12690..daaec9bd 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -89,9 +89,9 @@ def test_identifier_invalid(self):
         p = sqlparse.parse('a.')[0]
         self.assert_(isinstance(p.tokens[0], sql.Identifier))
         self.assertEqual(p.tokens[0].has_alias(), False)
-        self.assertEqual(p.tokens[0].get_name(), None)
-        self.assertEqual(p.tokens[0].get_real_name(), None)
-        self.assertEqual(p.tokens[0].get_parent_name(), 'a')
+        self.assertEqual(p.tokens[0].get_name(), 'a')
+        self.assertEqual(p.tokens[0].get_real_name(), 'a')
+        self.assertEqual(p.tokens[0].get_parent_name(), None)
 
     def test_identifier_as_invalid(self):  # issue8
         p = sqlparse.parse('foo as select *')[0]