From d8e8f30137a06f3e1b92993c2ec69f000a76e9ec Mon Sep 17 00:00:00 2001 From: Andrew Tipton Date: Wed, 2 Mar 2016 18:20:12 +0800 Subject: [PATCH 1/2] Add failing test for issue #227. --- tests/test_regressions.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_regressions.py b/tests/test_regressions.py index ca7dd5b2..acadefff 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -291,3 +291,10 @@ def test_issue212_py2unicode(): def test_issue213_leadingws(): sql = " select * from foo" assert sqlparse.format(sql, strip_whitespace=True) == "select * from foo" + + +def test_issue227_gettype_cte(): + select_stmt = sqlparse.parse('SELECT 1, 2, 3 FROM foo;')[0] + assert select_stmt.get_type() == 'SELECT' + with_stmt = sqlparse.parse('WITH foo AS (SELECT 1, 2, 3) SELECT * FROM foo;')[0] + assert with_stmt.get_type() == 'SELECT' From f516b66a0e254af510b6b8a18510aad922d69701 Mon Sep 17 00:00:00 2001 From: Andrew Tipton Date: Wed, 2 Mar 2016 18:37:28 +0800 Subject: [PATCH 2/2] Ensure get_type() works for queries that use WITH. --- sqlparse/keywords.py | 2 +- sqlparse/sql.py | 12 ++++++++++++ sqlparse/tokens.py | 1 + tests/test_regressions.py | 14 ++++++++++---- 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/sqlparse/keywords.py b/sqlparse/keywords.py index b6c4246f..dd08be0e 100644 --- a/sqlparse/keywords.py +++ b/sqlparse/keywords.py @@ -486,7 +486,7 @@ 'VOLATILE': tokens.Keyword, 'WHENEVER': tokens.Keyword, - 'WITH': tokens.Keyword, + 'WITH': tokens.Keyword.CTE, 'WITHOUT': tokens.Keyword, 'WORK': tokens.Keyword, 'WRITE': tokens.Keyword, diff --git a/sqlparse/sql.py b/sqlparse/sql.py index c1111bb4..9c0497a4 100644 --- a/sqlparse/sql.py +++ b/sqlparse/sql.py @@ -487,6 +487,18 @@ def get_type(self): elif first_token.ttype in (T.Keyword.DML, T.Keyword.DDL): return first_token.normalized + elif first_token.ttype == T.Keyword.CTE: + # The WITH keyword should be followed by either an Identifier or + # an IdentifierList containing the CTE definitions; the actual + # DML keyword (e.g. SELECT, INSERT) will follow next. + idents = self.token_next(self.token_index(first_token), skip_ws=True) + if isinstance(idents, (Identifier, IdentifierList)): + dml_keyword = self.token_next(self.token_index(idents), skip_ws=True) + if dml_keyword.ttype == T.Keyword.DML: + return dml_keyword.normalized + # Hmm, probably invalid syntax, so return unknown. + return 'UNKNOWN' + return 'UNKNOWN' diff --git a/sqlparse/tokens.py b/sqlparse/tokens.py index 01a9b896..98fa8a62 100644 --- a/sqlparse/tokens.py +++ b/sqlparse/tokens.py @@ -75,6 +75,7 @@ def __repr__(self): # SQL specific tokens DML = Keyword.DML DDL = Keyword.DDL +CTE = Keyword.CTE Command = Keyword.Command Group = Token.Group diff --git a/tests/test_regressions.py b/tests/test_regressions.py index acadefff..c66a42d9 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -294,7 +294,13 @@ def test_issue213_leadingws(): def test_issue227_gettype_cte(): - select_stmt = sqlparse.parse('SELECT 1, 2, 3 FROM foo;')[0] - assert select_stmt.get_type() == 'SELECT' - with_stmt = sqlparse.parse('WITH foo AS (SELECT 1, 2, 3) SELECT * FROM foo;')[0] - assert with_stmt.get_type() == 'SELECT' + select_stmt = sqlparse.parse('SELECT 1, 2, 3 FROM foo;') + assert select_stmt[0].get_type() == 'SELECT' + with_stmt = sqlparse.parse('WITH foo AS (SELECT 1, 2, 3) SELECT * FROM foo;') + assert with_stmt[0].get_type() == 'SELECT' + with2_stmt = sqlparse.parse(''' + WITH foo AS (SELECT 1 AS abc, 2 AS def), + bar AS (SELECT * FROM something WHERE x > 1) + INSERT INTO elsewhere SELECT * FROM foo JOIN bar; + ''') + assert with2_stmt[0].get_type() == 'INSERT'