Skip to content

Commit

Permalink
Merge pull request #177 from darikg/brackets
Browse files Browse the repository at this point in the history
Better square bracket / array index handling
  • Loading branch information
andialbrecht committed Mar 5, 2015
2 parents 15b0cb9 + acdebef commit bf26160
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 52 deletions.
80 changes: 61 additions & 19 deletions sqlparse/engine/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,21 @@ def _group_left_right(tlist, ttype, value, cls,
ttype, value)


def _find_matching(idx, tlist, start_ttype, start_value, end_ttype, end_value):
    """Return the token that closes the group opened at ``tlist.tokens[idx]``.

    Scans ``tlist.tokens`` from position ``idx`` to the end and tracks the
    nesting level: a token matching ``(start_ttype, start_value)`` opens a
    level, a token matching ``(end_ttype, end_value)`` closes one.

    ``idx`` is expected to be the index of the opening token itself, so
    that token is the one that opens the first level.

    Returns the closing token that brings the nesting level back to where
    it was before ``idx``, or ``None`` when no such token exists.
    """
    # Count from zero: the opening token at ``idx`` raises the level to 1
    # and its partner is the closing token that drops it back to 0.
    depth = 0
    for tok in tlist.tokens[idx:]:
        if tok.match(start_ttype, start_value):
            depth += 1
        elif tok.match(end_ttype, end_value):
            depth -= 1
            if depth == 0:
                return tok
    # Ran off the end of the list without balancing the opening token.
    return None


def _group_matching(tlist, start_ttype, start_value, end_ttype, end_value,
cls, include_semicolon=False, recurse=False):
def _find_matching(i, tl, stt, sva, ett, eva):
depth = 1
for n in xrange(i, len(tl.tokens)):
t = tl.tokens[n]
if t.match(stt, sva):
depth += 1
elif t.match(ett, eva):
depth -= 1
if depth == 1:
return t
return None

[_group_matching(sgroup, start_ttype, start_value, end_ttype, end_value,
cls, include_semicolon) for sgroup in tlist.get_sublists()
if recurse]
Expand Down Expand Up @@ -157,16 +159,17 @@ def _consume_cycle(tl, i):
lambda y: (y.match(T.Punctuation, '.')
or y.ttype in (T.Operator,
T.Wildcard,
T.ArrayIndex,
T.Name)),
T.Name)
or isinstance(y, sql.SquareBrackets)),
lambda y: (y.ttype in (T.String.Symbol,
T.Name,
T.Wildcard,
T.ArrayIndex,
T.Literal.String.Single,
T.Literal.Number.Integer,
T.Literal.Number.Float)
or isinstance(y, (sql.Parenthesis, sql.Function)))))
or isinstance(y, (sql.Parenthesis,
sql.SquareBrackets,
sql.Function)))))
for t in tl.tokens[i:]:
# Don't take whitespaces into account.
if t.ttype is T.Whitespace:
Expand Down Expand Up @@ -275,9 +278,48 @@ def group_identifier_list(tlist):
tcomma = next_


def group_parenthesis(tlist):
_group_matching(tlist, T.Punctuation, '(', T.Punctuation, ')',
sql.Parenthesis)
def group_brackets(tlist):
    """Group parentheses () or square brackets [] found in ``tlist``.

    Works like ``_group_matching`` but handles both bracket kinds in one
    pass, because round brackets can contain square bracket groups and
    vice versa.

    Each opening bracket that has a matching closing bracket is wrapped,
    together with everything in between, into a ``sql.Parenthesis`` or
    ``sql.SquareBrackets`` group via ``tlist.group_tokens``; the new group
    is then searched recursively. An opening bracket without a matching
    closing bracket is left in place and skipped.
    """
    # When tlist is itself a bracket group, start after its first token
    # (presumably its own opening bracket) so it is not grouped again.
    if isinstance(tlist, (sql.Parenthesis, sql.SquareBrackets)):
        search_from = 1
    else:
        search_from = 0

    while True:
        # Next opening bracket of either kind at or after search_from.
        opener = tlist.token_next_match(search_from, T.Punctuation,
                                        ['(', '['])
        if not opener:
            break

        opening_value = opener.value  # either '(' or '['
        is_round = opening_value == '('
        closing_value = ')' if is_round else ']'
        group_cls = sql.Parenthesis if is_round else sql.SquareBrackets

        opener_idx = tlist.token_index(opener)

        # Find the corresponding closing bracket
        closer = _find_matching(opener_idx, tlist,
                                T.Punctuation, opening_value,
                                T.Punctuation, closing_value)

        if closer is None:
            # Unbalanced opener: continue the search right behind it.
            search_from = opener_idx + 1
            continue

        grouped = tlist.group_tokens(group_cls,
                                     tlist.tokens_between(opener, closer))

        # Check for nested bracket groups within this group
        group_brackets(grouped)
        search_from = tlist.token_index(grouped) + 1


def group_comments(tlist):
Expand Down Expand Up @@ -393,7 +435,7 @@ def align_comments(tlist):
def group(tlist):
for func in [
group_comments,
group_parenthesis,
group_brackets,
group_functions,
group_where,
group_case,
Expand Down
6 changes: 4 additions & 2 deletions sqlparse/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,10 @@ class Lexer(object):
(r"'(''|\\\\|\\'|[^'])*'", tokens.String.Single),
# not a real string literal in ANSI SQL:
(r'(""|".*?[^\\]")', tokens.String.Symbol),
(r'(?<=[\w\]])(\[[^\]]*?\])', tokens.Punctuation.ArrayIndex),
(r'(\[[^\]]+\])', tokens.Name),
# sqlite names can be escaped with [square brackets]. The left bracket
# must not be preceded by a word character, a closing square bracket or
# a closing parenthesis -- otherwise it's probably an array index
(r'(?<![\w\])])(\[[^\]]+\])', tokens.Name),
(r'((LEFT\s+|RIGHT\s+|FULL\s+)?(INNER\s+|OUTER\s+|STRAIGHT\s+)?|(CROSS\s+|NATURAL\s+)?)?JOIN\b', tokens.Keyword),
(r'END(\s+IF|\s+LOOP)?\b', tokens.Keyword),
(r'NOT NULL\b', tokens.Keyword),
Expand Down
18 changes: 14 additions & 4 deletions sqlparse/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,11 +511,12 @@ def get_ordering(self):
return ordering.value.upper()

def get_array_indices(self):
"""Returns an iterator of index expressions as strings"""
"""Returns an iterator of index token lists"""

# Use [1:-1] index to discard the square brackets
return (tok.value[1:-1] for tok in self.tokens
if tok.ttype in T.ArrayIndex)
for tok in self.tokens:
if isinstance(tok, SquareBrackets):
# Use [1:-1] index to discard the square brackets
yield tok.tokens[1:-1]


class IdentifierList(TokenList):
Expand All @@ -542,6 +543,15 @@ def _groupable_tokens(self):
return self.tokens[1:-1]


class SquareBrackets(TokenList):
    """Token list for a square bracket group, brackets included.

    The first and last entries of ``tokens`` are treated as the enclosing
    brackets; see ``_groupable_tokens``.
    """

    __slots__ = ('value', 'ttype', 'tokens')

    @property
    def _groupable_tokens(self):
        """Return the tokens without the first and the last one."""
        inner_tokens = self.tokens[1:-1]
        return inner_tokens

class Assignment(TokenList):
"""An assignment like 'var := val;'"""
__slots__ = ('value', 'ttype', 'tokens')
Expand Down
1 change: 0 additions & 1 deletion sqlparse/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def __repr__(self):
String = Literal.String
Number = Literal.Number
Punctuation = Token.Punctuation
ArrayIndex = Punctuation.ArrayIndex
Operator = Token.Operator
Comparison = Operator.Comparison
Wildcard = Token.Wildcard
Expand Down
72 changes: 46 additions & 26 deletions tests/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_single_quotes_with_linebreaks(): # issue118
assert p[0].ttype is T.String.Single


def test_array_indexed_column():
def test_sqlite_identifiers():
# Make sure we still parse sqlite style escapes
p = sqlparse.parse('[col1],[col2]')[0].tokens
assert (len(p) == 1
Expand All @@ -227,39 +227,59 @@ def test_array_indexed_column():
types = [tok.ttype for tok in p.flatten()]
assert types == [T.Name, T.Operator, T.Name]


def test_simple_1d_array_index():
p = sqlparse.parse('col[1]')[0].tokens
assert (len(p) == 1
and tuple(p[0].get_array_indices()) == ('1',)
and p[0].get_name() == 'col')
assert len(p) == 1
assert p[0].get_name() == 'col'
indices = list(p[0].get_array_indices())
assert (len(indices) == 1 # 1-dimensional index
and len(indices[0]) == 1 # index is single token
and indices[0][0].value == '1')

p = sqlparse.parse('col[1][1:5] as mycol')[0].tokens
assert (len(p) == 1
and tuple(p[0].get_array_indices()) == ('1', '1:5')
and p[0].get_name() == 'mycol'
and p[0].get_real_name() == 'col')

p = sqlparse.parse('col[1][other_col]')[0].tokens
assert len(p) == 1 and tuple(p[0].get_array_indices()) == ('1', 'other_col')

sql = 'SELECT col1, my_1d_array[2] as alias1, my_2d_array[2][5] as alias2'
p = sqlparse.parse(sql)[0].tokens
assert len(p) == 3 and isinstance(p[2], sqlparse.sql.IdentifierList)
ids = list(p[2].get_identifiers())
assert (ids[0].get_name() == 'col1'
and tuple(ids[0].get_array_indices()) == ()
and ids[1].get_name() == 'alias1'
and ids[1].get_real_name() == 'my_1d_array'
and tuple(ids[1].get_array_indices()) == ('2',)
and ids[2].get_name() == 'alias2'
and ids[2].get_real_name() == 'my_2d_array'
and tuple(ids[2].get_array_indices()) == ('2', '5'))

def test_2d_array_index():
    """Two bracket groups after a column yield two array indices."""
    tokens = sqlparse.parse('col[x][(y+1)*2]')[0].tokens
    assert len(tokens) == 1
    identifier = tokens[0]
    assert identifier.get_name() == 'col'
    indices = list(identifier.get_array_indices())
    assert len(indices) == 2  # 2-dimensional index


def test_array_index_function_result():
    """An index applied to a function call result is grouped with it."""
    tokens = sqlparse.parse('somefunc()[1]')[0].tokens
    assert len(tokens) == 1
    indices = list(tokens[0].get_array_indices())
    assert len(indices) == 1


def test_schema_qualified_array_index():
    """A schema-qualified column keeps parent name, name and index."""
    tokens = sqlparse.parse('schem.col[1]')[0].tokens
    assert len(tokens) == 1
    identifier = tokens[0]
    assert identifier.get_parent_name() == 'schem'
    assert identifier.get_name() == 'col'
    first_index = list(identifier.get_array_indices())[0]
    assert first_index[0].value == '1'


def test_aliased_array_index():
    """An indexed column followed by an alias keeps alias, name and index."""
    tokens = sqlparse.parse('col[1] x')[0].tokens
    assert len(tokens) == 1
    identifier = tokens[0]
    assert identifier.get_alias() == 'x'
    assert identifier.get_real_name() == 'col'
    first_index = list(identifier.get_array_indices())[0]
    assert first_index[0].value == '1'


def test_array_literal():
    """Check top-level and flattened token counts for an ARRAY literal."""
    # See issue #176
    statement = sqlparse.parse('ARRAY[%s, %s]')[0]
    assert len(statement.tokens) == 2
    flattened = list(statement.flatten())
    assert len(flattened) == 7


def test_typed_array_definition():
# array indices aren't grouped with builtins, but make sure we can extract
# identifier names
p = sqlparse.parse('x int, y int[], z int')[0]
names = [x.get_name() for x in p.get_sublists()]
names = [x.get_name() for x in p.get_sublists()
if isinstance(x, sqlparse.sql.Identifier)]
assert names == ['x', 'y', 'z']


0 comments on commit bf26160

Please sign in to comment.