Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sql_parse): Ensure table extraction handles Jinja templating #27470

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion superset/commands/sql_lab/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,13 @@ def _run_sql_json_exec_from_scratch(self) -> SqlJsonExecutionStatus:
try:
logger.info("Triggering query_id: %i", query.id)

# Necessary to check access before rendering the Jinjafied query, as
# some Jinja macros execute statements upon rendering.
self._validate_access(query)
self._execution_context.set_query(query)
rendered_query = self._sql_query_render.render(self._execution_context)
validate_rendered_query = copy.copy(query)
validate_rendered_query.sql = rendered_query
self._validate_access(validate_rendered_query)
self._set_query_limit_if_required(rendered_query)
self._query_dao.update(
query, {"limit": self._execution_context.query.limit}
Expand Down
10 changes: 5 additions & 5 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import dateutil
from flask import current_app, has_request_context, request
from flask_babel import gettext as _
from jinja2 import DebugUndefined
from jinja2 import DebugUndefined, Environment
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.expression import bindparam
Expand Down Expand Up @@ -479,11 +479,11 @@ def __init__(
self._applied_filters = applied_filters
self._removed_filters = removed_filters
self._context: dict[str, Any] = {}
self._env = SandboxedEnvironment(undefined=DebugUndefined)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that this PR (and other logic) references _env I thought it should be made "public"—in Presto's pseudo public/private naming sense.

self.env: Environment = SandboxedEnvironment(undefined=DebugUndefined)
self.set_context(**kwargs)

# custom filters
self._env.filters["where_in"] = WhereInMacro(database.get_dialect())
self.env.filters["where_in"] = WhereInMacro(database.get_dialect())

def set_context(self, **kwargs: Any) -> None:
self._context.update(kwargs)
Expand All @@ -496,7 +496,7 @@ def process_template(self, sql: str, **kwargs: Any) -> str:
>>> process_template(sql)
"SELECT '2017-01-01T00:00:00'"
"""
template = self._env.from_string(sql)
template = self.env.from_string(sql)
kwargs.update(self._context)

context = validate_template_context(self.engine, kwargs)
Expand Down Expand Up @@ -643,7 +643,7 @@ class TrinoTemplateProcessor(PrestoTemplateProcessor):
engine = "trino"

def process_template(self, sql: str, **kwargs: Any) -> str:
template = self._env.from_string(sql)
template = self.env.from_string(sql)
kwargs.update(self._context)

# Backwards compatibility if migrating from Presto.
Expand Down
40 changes: 27 additions & 13 deletions superset/models/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@
from sqlalchemy.sql.elements import ColumnElement, literal_column

from superset import security_manager
from superset.exceptions import SupersetSecurityException
from superset.jinja_context import BaseTemplateProcessor, get_template_processor
from superset.models.helpers import (
AuditMixinNullable,
ExploreMixin,
ExtraJSONMixin,
ImportExportMixin,
)
from superset.sql_parse import CtasMethod, ParsedQuery, Table
from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table
from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label

Expand All @@ -65,8 +66,25 @@
logger = logging.getLogger(__name__)


class SqlTablesMixin:  # pylint: disable=too-few-public-methods
    """
    Mixin providing a ``sql_tables`` property for models carrying SQL.

    The host class is expected to expose ``sql`` and ``database`` attributes.
    """

    @property
    def sql_tables(self) -> list[Table]:
        """
        List the tables referenced by the (possibly Jinjafied) ``self.sql``.

        :returns: The referenced tables, or an empty list if the extraction raised
            a ``SupersetSecurityException``
        """
        try:
            # Both attribute lookups stay inside the ``try`` so that anything they
            # raise is handled exactly as the extraction call itself.
            statement = self.sql  # type: ignore
            engine = self.database.db_engine_spec.engine  # type: ignore
            referenced = list(extract_tables_from_jinja_sql(statement, engine))
        except SupersetSecurityException:
            return []
        return referenced

Check warning on line 80 in superset/models/sql_lab.py

View check run for this annotation

Codecov / codecov/patch

superset/models/sql_lab.py#L79-L80

Added lines #L79 - L80 were not covered by tests


class Query(
ExtraJSONMixin, ExploreMixin, Model
SqlTablesMixin,
ExtraJSONMixin,
ExploreMixin,
Model,
): # pylint: disable=abstract-method,too-many-public-methods
"""ORM model for SQL query

Expand Down Expand Up @@ -181,10 +199,6 @@
def username(self) -> str:
return self.user.username

@property
def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)

@property
def columns(self) -> list["TableColumn"]:
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
Expand Down Expand Up @@ -355,7 +369,13 @@
return self.make_sqla_column_compatible(sqla_column, label)


class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
class SavedQuery(
SqlTablesMixin,
AuditMixinNullable,
ExtraJSONMixin,
ImportExportMixin,
Model,
):
"""ORM model for SQL query"""

__tablename__ = "saved_query"
Expand Down Expand Up @@ -425,12 +445,6 @@
def url(self) -> str:
return f"/sqllab?savedQueryId={self.id}"

@property
def sql_tables(self) -> list[Table]:
return list(
ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
)

@property
def last_run_humanized(self) -> str:
return naturaltime(datetime.now() - self.changed_on)
Expand Down
13 changes: 4 additions & 9 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,12 @@
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import Query as SqlaQuery

from superset import sql_parse
from superset.constants import RouteMethod
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
DatasetInvalidPermissionEvaluationException,
SupersetSecurityException,
)
from superset.jinja_context import get_template_processor
from superset.security.guest_token import (
GuestToken,
GuestTokenResources,
Expand All @@ -68,6 +66,7 @@
GuestTokenUser,
GuestUser,
)
from superset.sql_parse import extract_tables_from_jinja_sql
from superset.superset_typing import Metric
from superset.utils.core import (
DatasourceName,
Expand Down Expand Up @@ -1961,16 +1960,12 @@ def raise_for_access(
return

if query:
# make sure the quuery is valid SQL by rendering any Jinja
processor = get_template_processor(database=query.database)
rendered_sql = processor.process_template(query.sql)
default_schema = database.get_default_schema_for_query(query)
tables = {
Table(table_.table, table_.schema or default_schema)
for table_ in sql_parse.ParsedQuery(
rendered_sql,
engine=database.db_engine_spec.engine,
).tables
for table_ in extract_tables_from_jinja_sql(
query.sql, database.db_engine_spec.engine
)
}
elif table:
tables = {table}
Expand Down
60 changes: 60 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast
from unittest.mock import Mock

import sqlglot
import sqlparse
from flask_babel import gettext as __
from jinja2 import nodes
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects.dialect import Dialect, Dialects
Expand Down Expand Up @@ -1232,3 +1234,61 @@ def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]:
Table(*[part["value"] for part in table["name"][::-1]])
for table in find_nodes_by_key(tree, "Table")
}


def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]:
    """
    Extract all table references in the Jinjafied SQL statement.

    Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL
    statement may represent invalid SQL which is non-parsable by SQLGlot.

    Firstly, we extract any tables referenced within the confines of specific Jinja
    macros. Secondly, we replace these non-SQL Jinja calls with a pseudo-benign SQL
    expression to help ensure that the resulting SQL statements are parsable by
    SQLGlot.

    A table is only extracted from a ``latest_partition``/``latest_sub_partition``
    call when that call has exactly one positional argument and the argument is a
    string literal. Any other call shape is still replaced with the benign SQL, but
    contributes no table.

    :param sql: The Jinjafied SQL statement
    :param engine: The associated database engine
    :returns: The set of tables referenced in the SQL statement
    :raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement
    """

    # NOTE(review): imported locally, presumably to avoid a circular import between
    # superset.sql_parse and superset.jinja_context; confirm before hoisting.
    from superset.jinja_context import (  # pylint: disable=import-outside-toplevel
        get_template_processor,
    )

    # Mock the required database as the processor signature is exposed publicly.
    processor = get_template_processor(database=Mock(backend=engine))
    template = processor.env.parse(sql)

    tables: set[Table] = set()

    # The nodes are rewritten while the generator is being consumed; do not
    # materialize it up front, as that would change which nested calls are visited.
    for node in template.find_all(nodes.Call):
        if isinstance(node.node, nodes.Getattr) and node.node.attr in (
            "latest_partition",
            "latest_sub_partition",
        ):
            # Extract the table referenced in the macro. The arity and literal
            # checks must happen *before* touching ``node.args[0].value``;
            # otherwise a call with no positional argument raises IndexError, a
            # non-literal argument (e.g. a Jinja variable) raises AttributeError,
            # and multiple positional arguments yield ``Table()`` with no fields.
            if len(node.args) == 1:
                argument = node.args[0]

                if isinstance(argument, nodes.Const) and isinstance(
                    argument.value, str
                ):
                    # "schema.table" -> Table("table", "schema")
                    parts = argument.value.split(".")[::-1]
                    tables.add(Table(*[remove_quotes(part) for part in parts]))

            # Replace the potentially problematic Jinja macro with some benign SQL.
            node.__class__ = nodes.TemplateData
            node.fields = nodes.TemplateData.fields
            node.data = "NULL"

    return (
        tables
        | ParsedQuery(
            sql_statement=processor.process_template(template),
            engine=engine,
        ).tables
    )
3 changes: 1 addition & 2 deletions superset/sqllab/query_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ def _validate(
sql_template_processor: BaseTemplateProcessor,
) -> None:
if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
# pylint: disable=protected-access
syntax_tree = sql_template_processor._env.parse(rendered_query)
syntax_tree = sql_template_processor.env.parse(rendered_query)
undefined_parameters = find_undeclared_variables(syntax_tree)
if undefined_parameters:
self._raise_undefined_parameter_exception(
Expand Down
41 changes: 41 additions & 0 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from superset.sql_parse import (
add_table_name,
extract_table_references,
extract_tables_from_jinja_sql,
get_rls_for_table,
has_table_query,
insert_rls_as_subquery,
Expand Down Expand Up @@ -1909,3 +1910,43 @@ def test_sqlstatement() -> None:

statement = SQLStatement("SET a=1")
assert statement.get_settings() == {"a": "1"}


@pytest.mark.parametrize("engine", ["hive", "presto", "trino"])
@pytest.mark.parametrize(
    "macro",
    [
        "latest_partition('foo.bar')",
        "latest_sub_partition('foo.bar', baz='qux')",
    ],
)
@pytest.mark.parametrize(
    "sql,expected",
    [
        (
            "SELECT '{{{{ {engine}.{macro} }}}}'",
            {Table(table="bar", schema="foo")},
        ),
        (
            "SELECT * FROM foo.baz WHERE quux = '{{{{ {engine}.{macro} }}}}'",
            {Table(table="bar", schema="foo"), Table(table="baz", schema="foo")},
        ),
    ],
)
def test_extract_tables_from_jinja_sql(
    engine: str,
    macro: str,
    sql: str,
    expected: set[Table],
) -> None:
    """
    Test that tables are extracted from both the Jinja macros and the plain SQL.
    """
    # The quadruple braces are ``str.format`` escapes: after formatting they
    # collapse to a regular ``{{ <engine>.<macro> }}`` Jinja expression.
    jinja_sql = sql.format(engine=engine, macro=macro)
    actual = extract_tables_from_jinja_sql(jinja_sql, engine)

    assert actual == expected
Loading