Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Adding SIP-15 support for the query context #9219

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions superset/common/query_object.py
Expand Up @@ -23,6 +23,7 @@

from superset import app
from superset.utils import core as utils
from superset.views.utils import get_time_range_endpoints

# TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type
# https://github.com/python/mypy/issues/5288
Expand Down Expand Up @@ -95,6 +96,10 @@ def __init__(
self.timeseries_limit_metric = timeseries_limit_metric
self.order_desc = order_desc
self.extras = extras or {}

if app.config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras:
self.extras["time_range_endpoints"] = get_time_range_endpoints(form_data={})

self.columns = columns or []
self.orderby = orderby or []

Expand Down
25 changes: 23 additions & 2 deletions tests/core_tests.py
Expand Up @@ -165,6 +165,17 @@ def test_cache_key_changes_when_datasource_is_updated(self):
# the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new)

def test_query_context_time_range_endpoints(self):
query_context = QueryContext(**self._get_query_context_dict())
query_object = query_context.queries[0]
extras = query_object.to_dict()["extras"]
self.assertTrue("time_range_endpoints" in extras)

self.assertEquals(
extras["time_range_endpoints"],
(utils.TimeRangeEndpoint.INCLUSIVE, utils.TimeRangeEndpoint.EXCLUSIVE),
)

def test_get_superset_tables_not_allowed(self):
example_db = utils.get_example_database()
schema_name = self.default_schema_backend_map[example_db.backend]
Expand Down Expand Up @@ -973,7 +984,12 @@ def test_results_default_deserialization(self):
"sql": "SELECT * FROM birth_names LIMIT 100",
"status": utils.QueryStatus.PENDING,
}
serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is black formatting.

serialized_data,
selected_columns,
all_columns,
expanded_columns,
) = sql_lab._serialize_and_expand_data(
results, db_engine_spec, use_new_deserialization
)
payload = {
Expand Down Expand Up @@ -1016,7 +1032,12 @@ def test_results_msgpack_deserialization(self):
"sql": "SELECT * FROM birth_names LIMIT 100",
"status": utils.QueryStatus.PENDING,
}
serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is black formatting.

serialized_data,
selected_columns,
all_columns,
expanded_columns,
) = sql_lab._serialize_and_expand_data(
results, db_engine_spec, use_new_deserialization
)
payload = {
Expand Down