Skip to content

Commit

Permalink
feat: improve _extract_tables_from_sql (apache#26748)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored and EandrewJones committed Apr 5, 2024
1 parent a04ad10 commit dbc5140
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 8 deletions.
19 changes: 16 additions & 3 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import sqlglot
import sqlparse
from flask_babel import gettext as __
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects.dialect import Dialect, Dialects
Expand Down Expand Up @@ -58,7 +59,12 @@
)
from sqlparse.utils import imt

from superset.exceptions import QueryClauseValidationException, SupersetParseError
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
QueryClauseValidationException,
SupersetParseError,
SupersetSecurityException,
)
from superset.utils.backports import StrEnum

try:
Expand Down Expand Up @@ -467,9 +473,16 @@ def _extract_tables_from_sql(self) -> set[Table]:
"""
try:
statements = parse(self.stripped(), dialect=self._dialect)
except SqlglotError:
except SqlglotError as ex:
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
return set()
dialect = self._dialect or "generic"
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
message=__(f"Unable to parse SQL ({dialect}): {self.sql}"),
level=ErrorLevel.ERROR,
)
) from ex

return {
table
Expand Down
8 changes: 8 additions & 0 deletions tests/unit_tests/jinja_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ def test_dataset_macro(mocker: MockFixture) -> None:
)
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")
DatasetDAO.find_by_id.return_value = dataset
mocker.patch(
"superset.connectors.sqla.models.security_manager.get_guest_rls_filters",
return_value=[],
)
mocker.patch(
"superset.models.helpers.security_manager.get_guest_rls_filters",
return_value=[],
)

assert (
dataset_macro(1)
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/security/manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def test_raise_for_access_query_default_schema(
sm = SupersetSecurityManager(appbuilder)
mocker.patch.object(sm, "can_access_database", return_value=False)
mocker.patch.object(sm, "get_schema_perm", return_value="[PostgreSQL].[public]")
mocker.patch.object(sm, "is_guest_user", return_value=False)
SqlaTable = mocker.patch("superset.connectors.sqla.models.SqlaTable")
SqlaTable.query_datasources_by_name.return_value = []

Expand Down
34 changes: 29 additions & 5 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from sqlparse.sql import Identifier, Token, TokenList
from sqlparse.tokens import Name

from superset.exceptions import QueryClauseValidationException
from superset.exceptions import (
QueryClauseValidationException,
SupersetSecurityException,
)
from superset.sql_parse import (
add_table_name,
extract_table_references,
Expand Down Expand Up @@ -267,13 +270,34 @@ def test_extract_tables_illdefined() -> None:
"""
Test that ill-defined tables return an empty set.
"""
assert extract_tables("SELECT * FROM schemaname.") == set()
assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
assert extract_tables("SELECT * FROM catalogname..") == set()
with pytest.raises(SupersetSecurityException) as excinfo:
extract_tables("SELECT * FROM schemaname.")
assert (
str(excinfo.value) == "Unable to parse SQL (generic): SELECT * FROM schemaname."
)

with pytest.raises(SupersetSecurityException) as excinfo:
extract_tables("SELECT * FROM catalogname.schemaname.")
assert (
str(excinfo.value)
== "Unable to parse SQL (generic): SELECT * FROM catalogname.schemaname."
)

with pytest.raises(SupersetSecurityException) as excinfo:
extract_tables("SELECT * FROM catalogname..")
assert (
str(excinfo.value)
== "Unable to parse SQL (generic): SELECT * FROM catalogname.."
)

with pytest.raises(SupersetSecurityException) as excinfo:
extract_tables('SELECT * FROM "tbname')
assert str(excinfo.value) == 'Unable to parse SQL (generic): SELECT * FROM "tbname'

# odd edge case that works
assert extract_tables("SELECT * FROM catalogname..tbname") == {
Table(table="tbname", schema=None, catalog="catalogname")
}
assert extract_tables('SELECT * FROM "tbname') == set()


def test_extract_tables_show_tables_from() -> None:
Expand Down

0 comments on commit dbc5140

Please sign in to comment.