Skip to content

Commit

Permalink
Merge pull request #171 from darikg/alias_bugfix
Browse files Browse the repository at this point in the history
Fix #167
  • Loading branch information
andialbrecht committed Feb 21, 2015
2 parents 51871a8 + 6f134c6 commit 2d72b7a
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 42 deletions.
7 changes: 4 additions & 3 deletions sqlparse/engine/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ def _consume_cycle(tl, i):
# TODO: Usage of Wildcard token is ambivalent here.
x = itertools.cycle((
lambda y: (y.match(T.Punctuation, '.')
or y.ttype is T.Operator
or y.ttype is T.Wildcard
or y.ttype is T.ArrayIndex),
or y.ttype in (T.Operator,
T.Wildcard,
T.ArrayIndex,
T.Name)),
lambda y: (y.ttype in (T.String.Symbol,
T.Name,
T.Wildcard,
Expand Down
76 changes: 42 additions & 34 deletions sqlparse/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,21 +390,17 @@ def has_alias(self):

def get_alias(self):
"""Returns the alias for this identifier or ``None``."""

# "name AS alias"
kw = self.token_next_match(0, T.Keyword, 'AS')
if kw is not None:
alias = self.token_next(self.token_index(kw))
if alias is None:
return None
else:
next_ = self.token_next_by_instance(0, Identifier)
if next_ is None:
next_ = self.token_next_by_type(0, T.String.Symbol)
if next_ is None:
return None
alias = next_
if isinstance(alias, Identifier):
return alias.get_name()
return self._remove_quotes(unicode(alias))
return self._get_first_name(kw, keywords=True)

# "name alias" or "complicated column expression alias"
if len(self.tokens) > 2:
return self._get_first_name(reverse=True)

return None

def get_name(self):
"""Returns the name of this identifier.
Expand All @@ -422,18 +418,43 @@ def get_real_name(self):
"""Returns the real name (object name) of this identifier."""
# a.b
dot = self.token_next_match(0, T.Punctuation, '.')
if dot is not None:
return self._get_first_name(self.token_index(dot))

return self._get_first_name()

def get_parent_name(self):
"""Return name of the parent object if any.
A parent object is identified by the first occuring dot.
"""
dot = self.token_next_match(0, T.Punctuation, '.')
if dot is None:
next_ = self.token_next_by_type(0, T.Name)
if next_ is not None:
return self._remove_quotes(next_.value)
return None

next_ = self.token_next_by_type(self.token_index(dot),
(T.Name, T.Wildcard, T.String.Symbol))
if next_ is None: # invalid identifier, e.g. "a."
prev_ = self.token_prev(self.token_index(dot))
if prev_ is None: # something must be verry wrong here..
return None
return self._remove_quotes(next_.value)
return self._remove_quotes(prev_.value)

def _get_first_name(self, idx=None, reverse=False, keywords=False):
"""Returns the name of the first token with a name"""

if idx and not isinstance(idx, int):
idx = self.token_index(idx) + 1

tokens = self.tokens[idx:] if idx else self.tokens
tokens = reversed(tokens) if reverse else tokens
types = [T.Name, T.Wildcard, T.String.Symbol]

if keywords:
types.append(T.Keyword)

for tok in tokens:
if tok.ttype in types:
return self._remove_quotes(tok.value)
elif isinstance(tok, Identifier) or isinstance(tok, Function):
return tok.get_name()
return None

class Statement(TokenList):
"""Represents a SQL statement."""
Expand Down Expand Up @@ -467,19 +488,6 @@ class Identifier(TokenList):

__slots__ = ('value', 'ttype', 'tokens')

def get_parent_name(self):
"""Return name of the parent object if any.
A parent object is identified by the first occuring dot.
"""
dot = self.token_next_match(0, T.Punctuation, '.')
if dot is None:
return None
prev_ = self.token_prev(self.token_index(dot))
if prev_ is None: # something must be verry wrong here..
return None
return self._remove_quotes(prev_.value)

def is_wildcard(self):
"""Return ``True`` if this identifier contains a wildcard."""
token = self.token_next_by_type(0, T.Wildcard)
Expand Down
51 changes: 46 additions & 5 deletions tests/test_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ def test_parenthesis(self):
s = 'select (select (x3) x2) and (y2) bar'
parsed = sqlparse.parse(s)[0]
self.ndiffAssertEqual(s, str(parsed))
self.assertEqual(len(parsed.tokens), 9)
self.assertEqual(len(parsed.tokens), 7)
self.assert_(isinstance(parsed.tokens[2], sql.Parenthesis))
self.assert_(isinstance(parsed.tokens[-3], sql.Parenthesis))
self.assertEqual(len(parsed.tokens[2].tokens), 7)
self.assert_(isinstance(parsed.tokens[2].tokens[3], sql.Parenthesis))
self.assert_(isinstance(parsed.tokens[-1], sql.Identifier))
self.assertEqual(len(parsed.tokens[2].tokens), 5)
self.assert_(isinstance(parsed.tokens[2].tokens[3], sql.Identifier))
self.assert_(isinstance(parsed.tokens[2].tokens[3].tokens[0], sql.Parenthesis))
self.assertEqual(len(parsed.tokens[2].tokens[3].tokens), 3)

def test_comments(self):
Expand Down Expand Up @@ -145,7 +146,7 @@ def test_where(self):
s = 'select x from (select y from foo where bar = 1) z'
p = sqlparse.parse(s)[0]
self.ndiffAssertEqual(s, unicode(p))
self.assertTrue(isinstance(p.tokens[-3].tokens[-2], sql.Where))
self.assertTrue(isinstance(p.tokens[-1].tokens[0].tokens[-2], sql.Where))

def test_typecast(self):
s = 'select foo::integer from bar'
Expand Down Expand Up @@ -345,3 +346,43 @@ def test_nested_begin():
assert inner.tokens[0].value == 'BEGIN'
assert inner.tokens[-1].value == 'END'
assert isinstance(inner, sql.Begin)


def test_aliased_column_without_as():
p = sqlparse.parse('foo bar')[0].tokens
assert len(p) == 1
assert p[0].get_real_name() == 'foo'
assert p[0].get_alias() == 'bar'

p = sqlparse.parse('foo.bar baz')[0].tokens[0]
assert p.get_parent_name() == 'foo'
assert p.get_real_name() == 'bar'
assert p.get_alias() == 'baz'


def test_qualified_function():
p = sqlparse.parse('foo()')[0].tokens[0]
assert p.get_parent_name() is None
assert p.get_real_name() == 'foo'

p = sqlparse.parse('foo.bar()')[0].tokens[0]
assert p.get_parent_name() == 'foo'
assert p.get_real_name() == 'bar'


def test_aliased_function_without_as():
p = sqlparse.parse('foo() bar')[0].tokens[0]
assert p.get_parent_name() is None
assert p.get_real_name() == 'foo'
assert p.get_alias() == 'bar'

p = sqlparse.parse('foo.bar() baz')[0].tokens[0]
assert p.get_parent_name() == 'foo'
assert p.get_real_name() == 'bar'
assert p.get_alias() == 'baz'


def test_aliased_literal_without_as():
p = sqlparse.parse('1 foo')[0].tokens
assert len(p) == 1
assert p[0].get_alias() == 'foo'

0 comments on commit 2d72b7a

Please sign in to comment.