Skip to content

Commit

Permalink
fix(sqla): apply jinja to metrics (#19565)
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro committed Apr 7, 2022
1 parent db21351 commit 34b5576
Showing 1 changed file with 52 additions and 31 deletions.
83 changes: 52 additions & 31 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def get_timestamp_expression(
:param time_grain: Optional time grain, e.g. P1Y
:param label: alias/label that column is expected to have
:param template_processor: template processor
:return: A TimeExpression object wrapped in a Label if supported by db
"""
label = label or utils.DTTM_ALIAS
Expand Down Expand Up @@ -488,6 +489,27 @@ def data(self) -> Dict[str, Any]:
)


def _process_sql_expression(
expression: Optional[str],
database_id: int,
schema: str,
template_processor: Optional[BaseTemplateProcessor],
) -> Optional[str]:
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
expression = validate_adhoc_subquery(
expression,
database_id,
schema,
)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
return expression


class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-methods
"""An ORM object for SqlAlchemy table references"""

Expand Down Expand Up @@ -875,13 +897,17 @@ def get_rendered_sql(
return sql

def adhoc_metric_to_sqla(
self, metric: AdhocMetric, columns_by_name: Dict[str, TableColumn]
self,
metric: AdhocMetric,
columns_by_name: Dict[str, TableColumn],
template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
"""
Turn an adhoc metric into a sqlalchemy column.
:param dict metric: Adhoc metric definition
:param dict columns_by_name: Columns for the current table
:param template_processor: template_processor instance
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
"""
Expand All @@ -898,17 +924,12 @@ def adhoc_metric_to_sqla(
sqla_column = column(column_name)
sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
elif expression_type == utils.AdhocMetricExpressionType.SQL:
tp = self.get_template_processor()
expression = tp.process_template(cast(str, metric["sqlExpression"]))
expression = validate_adhoc_subquery(
expression,
self.database_id,
self.schema,
expression = _process_sql_expression(
expression=metric["sqlExpression"],
database_id=self.database_id,
schema=self.schema,
template_processor=template_processor,
)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
sqla_metric = literal_column(expression)
else:
raise QueryObjectValidationError("Adhoc metric expressionType is invalid")
Expand All @@ -929,21 +950,14 @@ def adhoc_column_to_sqla(
:rtype: sqlalchemy.sql.column
"""
label = utils.get_column_name(col)
expression = col["sqlExpression"]
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
expression = validate_adhoc_subquery(
expression,
self.database_id,
self.schema,
)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
sqla_metric = literal_column(expression)
return self.make_sqla_column_compatible(sqla_metric, label)
expression = _process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
schema=self.schema,
template_processor=template_processor,
)
sqla_column = literal_column(expression)
return self.make_sqla_column_compatible(sqla_column, label)

def make_sqla_column_compatible(
self, sqla_col: ColumnElement, label: Optional[str] = None
Expand Down Expand Up @@ -1127,7 +1141,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
for metric in metrics:
if utils.is_adhoc_metric(metric):
assert isinstance(metric, dict)
metrics_exprs.append(self.adhoc_metric_to_sqla(metric, columns_by_name))
metrics_exprs.append(
self.adhoc_metric_to_sqla(
metric=metric,
columns_by_name=columns_by_name,
template_processor=template_processor,
)
)
elif isinstance(metric, str) and metric in metrics_by_name:
metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
else:
Expand All @@ -1154,10 +1174,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
if isinstance(col, dict):
col = cast(AdhocMetric, col)
if col.get("sqlExpression"):
col["sqlExpression"] = validate_adhoc_subquery(
cast(str, col["sqlExpression"]),
self.database_id,
self.schema,
col["sqlExpression"] = _process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
schema=self.schema,
template_processor=template_processor,
)
if utils.is_adhoc_metric(col):
# add adhoc sort by column to columns_by_name if not exists
Expand Down

0 comments on commit 34b5576

Please sign in to comment.