Skip to content

Commit

Permalink
Add CUSTOM_TEMPLATE_PROCESSOR config
Browse files Browse the repository at this point in the history
  • Loading branch information
Dandan Shi committed Apr 6, 2020
1 parent 801e2f1 commit 8e63194
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 2 deletions.
53 changes: 53 additions & 0 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,59 @@ in this dictionary are made available for users to use in their SQL.
'my_crazy_macro': lambda x: x*2,
}
Besides default Jinja templating, SQL lab also supports self-defined template
processor by setting the ``CUSTOM_TEMPLATE_PROCESSORS`` in your superset configuration.
The values in this dictionary overwrite the default Jinja template processors of the
specified database engine.
The example below configures a custom presto template processor which implements
its own logic of processing macro template with regex parsing. It uses ``$`` style
macro instead of ``{{ }}`` style in Jinja templating. By configuring it with
``CUSTOM_TEMPLATE_PROCESSORS``, sql template on presto database is processed
by the custom one rather than the default one.

.. code-block:: python
def DATE(
ts: datetime, day_offset: SupportsInt = 0, hour_offset: SupportsInt = 0
) -> str:
"""Current day as a string."""
day_offset, hour_offset = int(day_offset), int(hour_offset)
offset_day = (ts + timedelta(days=day_offset, hours=hour_offset)).date()
return str(offset_day)
class CustomPrestoTemplateProcessor(PrestoTemplateProcessor):
"""A custom presto template processor."""
engine = "presto"
def process_template(self, sql: str, **kwargs) -> str:
"""Processes a sql template with $ style macro using regex."""
# Add custom macros functions.
macros = {
"DATE": partial(DATE, datetime.utcnow())
} # type: Dict[str, Any]
# Update with macros defined in context and kwargs.
macros.update(self.context)
macros.update(kwargs)
def replacer(match):
"""Expand $ style macros with corresponding function calls."""
macro_name, args_str = match.groups()
args = [a.strip() for a in args_str.split(",")]
if args == [""]:
args = []
f = macros[macro_name[1:]]
return f(*args)
macro_names = ["$" + name for name in macros.keys()]
pattern = r"(%s)\s*\(([^()]*)\)" % "|".join(map(re.escape, macro_names))
return re.sub(pattern, replacer, sql)
CUSTOM_TEMPLATE_PROCESSORS = {
CustomPrestoTemplateProcessor.engine: CustomPrestoTemplateProcessor
}
SQL Lab also includes a live query validation feature with pluggable backends.
You can configure which validation implementation is used with which database
engine by adding a block like the following to your config.py:
Expand Down
9 changes: 9 additions & 0 deletions docs/sqllab.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ environment using the configuration variable ``JINJA_CONTEXT_ADDONS``.
All objects referenced in this dictionary will become available for users
to integrate in their queries in **SQL Lab**.

Customize templating
''''''''''''''''''''

As mentioned in the `Installation & Configuration <https://superset.incubator.apache.org/installation.html#sql-lab>`__ documentation,
it's possible for administrators to overwrite Jinja templating with your customized
template processor using the configuration variable ``CUSTOM_TEMPLATE_PROCESSORS``.
The template processors referenced in the dictionary will overwrite default Jinja template processors
of the specified database engines.

Query cost estimation
'''''''''''''''''''''

Expand Down
10 changes: 10 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
from dateutil import tz
from flask_appbuilder.security.manager import AUTH_DB

from superset.jinja_context import ( # pylint: disable=unused-import
BaseTemplateProcessor,
)
from superset.stats_logger import DummyStatsLogger
from superset.typing import CacheConfig
from superset.utils.log import DBEventLogger
Expand Down Expand Up @@ -585,6 +588,13 @@ class CeleryConfig: # pylint: disable=too-few-public-methods
# dictionary.
JINJA_CONTEXT_ADDONS: Dict[str, Callable] = {}

# A dictionary of macro template processors that gets merged into global
# template processors. The existing template processors get updated with this
# dictionary, which means the existing keys get overwritten by the content of this
# dictionary. The customized addons don't necessarily need to use jinjia templating
# language. This allows you to define custom logic to process macro template.
CUSTOM_TEMPLATE_PROCESSORS = {} # type: Dict[str, BaseTemplateProcessor]

# Roles that are controlled by the API / Superset and should not be changes
# by humans.
ROBOT_PERMISSION_ROLES = ["Public", "Gamma", "Alpha", "Admin", "sql_lab"]
Expand Down
15 changes: 14 additions & 1 deletion superset/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import time
import uuid
from datetime import datetime, timedelta
from typing import Dict, TYPE_CHECKING # pylint: disable=unused-import

import celery
from dateutil.relativedelta import relativedelta
Expand All @@ -31,6 +32,12 @@
from superset.utils.cache_manager import CacheManager
from superset.utils.feature_flag_manager import FeatureFlagManager

# Avoid circular import
if TYPE_CHECKING:
from superset.jinja_context import ( # pylint: disable=unused-import
BaseTemplateProcessor,
)


class JinjaContextManager:
def __init__(self) -> None:
Expand All @@ -42,14 +49,20 @@ def __init__(self) -> None:
"timedelta": timedelta,
"uuid": uuid,
}
self._template_processors = {} # type: Dict[str, BaseTemplateProcessor]

def init_app(self, app):
self._base_context.update(app.config["JINJA_CONTEXT_ADDONS"])
self._template_processors.update(app.config["CUSTOM_TEMPLATE_PROCESSORS"])

@property
def base_context(self):
return self._base_context

@property
def template_processors(self):
return self._template_processors


class ResultsBackendManager:
def __init__(self) -> None:
Expand Down Expand Up @@ -120,7 +133,7 @@ def get_manifest_files(self, bundle, asset_type):
_event_logger: dict = {}
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
feature_flag_manager = FeatureFlagManager()
jinja_context_manager = JinjaContextManager()
jinja_context_manager = JinjaContextManager() # type: JinjaContextManager
manifest_processor = UIManifestProcessor(APP_DIR)
migrate = Migrate()
results_backend_manager = ResultsBackendManager()
Expand Down
4 changes: 3 additions & 1 deletion superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from jinja2.sandbox import SandboxedEnvironment

from superset import jinja_base_context
from superset.extensions import jinja_context_manager


def url_param(param: str, default: Optional[str] = None) -> Optional[Any]:
Expand Down Expand Up @@ -263,7 +264,8 @@ class HiveTemplateProcessor(PrestoTemplateProcessor):
engine = "hive"


template_processors = {}
# The global template processors from Jinja context manager.
template_processors = jinja_context_manager.template_processors
keys = tuple(globals().keys())
for k in keys:
o = globals()[k]
Expand Down
22 changes: 22 additions & 0 deletions tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,28 @@ def delete_fake_db(self):
if database:
db.session.delete(database)

def create_fake_presto_db(self):
self.login(username="admin")
database_name = "presto"
db_id = 200
return self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
session=db.session,
sqlalchemy_uri="presto://user@host:8080/hive",
id=db_id,
)

def delete_fake_presto_db(self):
database = (
db.session.query(Database)
.filter(Database.database_name == "presto")
.scalar()
)
if database:
db.session.delete(database)
db.session.commit()

def validate_sql(
self,
sql,
Expand Down
83 changes: 83 additions & 0 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,89 @@ def test_templated_sql_json(self):
data = self.run_sql(sql, "fdaklj3ws")
self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00")

@mock.patch("tests.superset_test_custom_template_processors.datetime")
def test_custom_process_template(self, mock_dt) -> None:
"""Test macro defined in custom template processor works."""
mock_dt.utcnow = mock.Mock(return_value=datetime.datetime(1970, 1, 1))
db = mock.Mock()
db.backend = "presto"
tp = jinja_context.get_template_processor(database=db)

sql = "SELECT '$DATE()'"
rendered = tp.process_template(sql)
self.assertEqual("SELECT '{}'".format("1970-01-01"), rendered)

sql = "SELECT '$DATE(1, 2)'"
rendered = tp.process_template(sql)
self.assertEqual("SELECT '{}'".format("1970-01-02"), rendered)

def test_custom_get_template_kwarg(self):
"""Test macro passed as kwargs when getting template processor
works in custom template processor."""
db = mock.Mock()
db.backend = "presto"
s = "$foo()"
tp = jinja_context.get_template_processor(database=db, foo=lambda: "bar")
rendered = tp.process_template(s)
self.assertEqual("bar", rendered)

def test_custom_template_kwarg(self) -> None:
"""Test macro passed as kwargs when processing template
works in custom template processor."""
db = mock.Mock()
db.backend = "presto"
s = "$foo()"
tp = jinja_context.get_template_processor(database=db)
rendered = tp.process_template(s, foo=lambda: "bar")
self.assertEqual("bar", rendered)

def test_custom_template_processors_overwrite(self) -> None:
"""Test template processor for presto gets overwritten by custom one."""
db = mock.Mock()
db.backend = "presto"
tp = jinja_context.get_template_processor(database=db)

sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
rendered = tp.process_template(sql)
self.assertEqual(sql, rendered)

sql = "SELECT '{{ DATE(1, 2) }}'"
rendered = tp.process_template(sql)
self.assertEqual(sql, rendered)

def test_custom_template_processors_ignored(self) -> None:
"""Test custom template processor is ignored for a difference backend
database."""
maindb = utils.get_example_database()
sql = "SELECT '$DATE()'"
tp = jinja_context.get_template_processor(database=maindb)
rendered = tp.process_template(sql)
self.assertEqual(sql, rendered)

@mock.patch("tests.superset_test_custom_template_processors.datetime")
@mock.patch("superset.sql_lab.get_sql_results")
def test_custom_templated_sql_json(self, sql_lab_mock, mock_dt) -> None:
"""Test sqllab receives macros expanded query."""
mock_dt.utcnow = mock.Mock(return_value=datetime.datetime(1970, 1, 1))
self.login("admin")
sql = "SELECT '$DATE()' as test"
resp = {
"status": utils.QueryStatus.SUCCESS,
"query": {"rows": 1},
"data": [{"test": "'1970-01-01'"}],
}
sql_lab_mock.return_value = resp

dbobj = self.create_fake_presto_db()
json_payload = dict(database_id=dbobj.id, sql=sql)
self.get_json_resp(
"/superset/sql_json/", raise_on_error=False, json_=json_payload
)
assert sql_lab_mock.called
self.assertEqual(sql_lab_mock.call_args[0][1], "SELECT '1970-01-01' as test")

self.delete_fake_presto_db()

def test_fetch_datasource_metadata(self):
self.login(username="admin")
url = "/superset/fetch_datasource_metadata?" "datasourceKey=1__table"
Expand Down
5 changes: 5 additions & 0 deletions tests/superset_test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from copy import copy

from superset.config import *
from tests.superset_test_custom_template_processors import CustomPrestoTemplateProcessor

AUTH_USER_REGISTRATION_ROLE = "alpha"
SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")
Expand Down Expand Up @@ -57,3 +58,7 @@ class CeleryConfig(object):


CELERY_CONFIG = CeleryConfig

CUSTOM_TEMPLATE_PROCESSORS = {
CustomPrestoTemplateProcessor.engine: CustomPrestoTemplateProcessor
}
59 changes: 59 additions & 0 deletions tests/superset_test_custom_template_processors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import re
from datetime import datetime, timedelta
from functools import partial
from typing import Any, Dict, SupportsInt

from superset.jinja_context import PrestoTemplateProcessor


def DATE(
ts: datetime, day_offset: SupportsInt = 0, hour_offset: SupportsInt = 0
) -> str:
"""Current day as a string"""
day_offset, hour_offset = int(day_offset), int(hour_offset)
offset_day = (ts + timedelta(days=day_offset, hours=hour_offset)).date()
return str(offset_day)


class CustomPrestoTemplateProcessor(PrestoTemplateProcessor):
"""A custom presto template processor for test."""

engine = "presto"

def process_template(self, sql: str, **kwargs) -> str:
"""Processes a sql template with $ style macro using regex."""
# Add custom macros functions.
macros = {"DATE": partial(DATE, datetime.utcnow())} # type: Dict[str, Any]
# Update with macros defined in context and kwargs.
macros.update(self.context)
macros.update(kwargs)

def replacer(match):
"""Expands $ style macros with corresponding function calls."""
macro_name, args_str = match.groups()
args = [a.strip() for a in args_str.split(",")]
if args == [""]:
args = []
f = macros[macro_name[1:]]
return f(*args)

macro_names = ["$" + name for name in macros.keys()]
pattern = r"(%s)\s*\(([^()]*)\)" % "|".join(map(re.escape, macro_names))
return re.sub(pattern, replacer, sql)

0 comments on commit 8e63194

Please sign in to comment.