Skip to content

Commit

Permalink
Merge pull request #228 from compareasiagroup/master
Browse files Browse the repository at this point in the history
Fix for #227 (get_type() doesn't work for queries that use WITH)
  • Loading branch information
andialbrecht committed Mar 7, 2016
2 parents 56e72ac + f516b66 commit 8f39d33
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sqlparse/keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@
'VOLATILE': tokens.Keyword,

'WHENEVER': tokens.Keyword,
'WITH': tokens.Keyword,
'WITH': tokens.Keyword.CTE,
'WITHOUT': tokens.Keyword,
'WORK': tokens.Keyword,
'WRITE': tokens.Keyword,
Expand Down
12 changes: 12 additions & 0 deletions sqlparse/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,18 @@ def get_type(self):
elif first_token.ttype in (T.Keyword.DML, T.Keyword.DDL):
return first_token.normalized

elif first_token.ttype == T.Keyword.CTE:
# The WITH keyword should be followed by either an Identifier or
# an IdentifierList containing the CTE definitions; the actual
# DML keyword (e.g. SELECT, INSERT) will follow next.
idents = self.token_next(self.token_index(first_token), skip_ws=True)
if isinstance(idents, (Identifier, IdentifierList)):
dml_keyword = self.token_next(self.token_index(idents), skip_ws=True)
if dml_keyword.ttype == T.Keyword.DML:
return dml_keyword.normalized
# Hmm, probably invalid syntax, so return unknown.
return 'UNKNOWN'

return 'UNKNOWN'


Expand Down
1 change: 1 addition & 0 deletions sqlparse/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __repr__(self):
# SQL specific tokens
DML = Keyword.DML
DDL = Keyword.DDL
CTE = Keyword.CTE
Command = Keyword.Command

Group = Token.Group
Expand Down
13 changes: 13 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,16 @@ def test_issue212_py2unicode():
def test_issue213_leadingws():
sql = " select * from foo"
assert sqlparse.format(sql, strip_whitespace=True) == "select * from foo"


def test_issue227_gettype_cte():
select_stmt = sqlparse.parse('SELECT 1, 2, 3 FROM foo;')
assert select_stmt[0].get_type() == 'SELECT'
with_stmt = sqlparse.parse('WITH foo AS (SELECT 1, 2, 3) SELECT * FROM foo;')
assert with_stmt[0].get_type() == 'SELECT'
with2_stmt = sqlparse.parse('''
WITH foo AS (SELECT 1 AS abc, 2 AS def),
bar AS (SELECT * FROM something WHERE x > 1)
INSERT INTO elsewhere SELECT * FROM foo JOIN bar;
''')
assert with2_stmt[0].get_type() == 'INSERT'

0 comments on commit 8f39d33

Please sign in to comment.