Skip to content

Commit

Permalink
fix(sqlglot): Address regressions introduced in #26476 (#27217)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley committed Feb 23, 2024
1 parent 6d88701 commit 2c56481
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
17 changes: 11 additions & 6 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects import Dialects
from sqlglot.errors import ParseError
from sqlglot.errors import SqlglotError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
Expand Down Expand Up @@ -287,7 +287,7 @@ def _extract_tables_from_sql(self) -> set[Table]:
"""
try:
statements = parse(self.stripped(), dialect=self._dialect)
except ParseError:
except SqlglotError:
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
return set()

Expand Down Expand Up @@ -319,12 +319,17 @@ def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table
elif isinstance(statement, exp.Command):
# Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
# `SELECT` statetement in order to extract tables.
literal = statement.find(exp.Literal)
if not literal:
if not (literal := statement.find(exp.Literal)):
return set()

pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
sources = pseudo_query.find_all(exp.Table)
try:
pseudo_query = parse_one(
f"SELECT {literal.this}",
dialect=self._dialect,
)
sources = pseudo_query.find_all(exp.Table)
except SqlglotError:
return set()
else:
sources = [
source
Expand Down
10 changes: 6 additions & 4 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def test_extract_tables_illdefined() -> None:
assert extract_tables("SELECT * FROM catalogname..tbname") == {
Table(table="tbname", schema=None, catalog="catalogname")
}
assert extract_tables('SELECT * FROM "tbname') == set()


def test_extract_tables_show_tables_from() -> None:
Expand Down Expand Up @@ -558,6 +559,10 @@ def test_extract_tables_multistatement() -> None:
Table("t1"),
Table("t2"),
}
assert extract_tables(
"ADD JAR file:///hive.jar; SELECT * FROM t1;",
engine="hive",
) == {Table("t1")}


def test_extract_tables_complex() -> None:
Expand Down Expand Up @@ -1815,10 +1820,7 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
# test falling back to sqlparse
logger = mocker.patch("superset.sql_parse.logger")
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
assert extract_table_references(
sql,
"trino",
) == {
assert extract_table_references(sql, "trino") == {
Table(table="table", schema=None, catalog=None),
Table(table="other_table", schema=None, catalog=None),
}
Expand Down

0 comments on commit 2c56481

Please sign in to comment.