Skip to content

Commit

Permalink
[sip-15] Fix time range endpoints encoding (#8481)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored and Grace committed Nov 13, 2019
1 parent 706da5e commit 29507ba
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 16 deletions.
28 changes: 13 additions & 15 deletions superset/views/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,35 +209,33 @@ def get_time_range_endpoints(
form_data: Dict[str, Any], slc: Optional[models.Slice]
) -> Optional[Tuple[TimeRangeEndpoint, TimeRangeEndpoint]]:
"""
Get the slice aware time range endpoints falling back to the SQL database specific
definition or default if not defined.
Get the slice aware time range endpoints from the form-data falling back to the SQL
database specific definition or default if not defined.
For SIP-15 all new slices use the [start, end) interval which is consistent with the
Druid REST API.
native Druid connector.
:param form_data: The form-data
:param slc: The chart
:returns: The time range endpoints tuple
"""

time_range_endpoints = form_data.get("time_range_endpoints")
endpoints = form_data.get("time_range_endpoints")

if time_range_endpoints:
return time_range_endpoints
if slc and not endpoints:
try:
_, datasource_type = get_datasource_info(None, None, form_data)
except SupersetException:
return None

try:
_, datasource_type = get_datasource_info(None, None, form_data)
except SupersetException:
return None

if datasource_type == "table":
if slc:
if datasource_type == "table":
endpoints = slc.datasource.database.get_extra().get("time_range_endpoints")

if not endpoints:
endpoints = app.config["SIP_15_DEFAULT_TIME_RANGE_ENDPOINTS"]

start, end = endpoints
return (TimeRangeEndpoint(start), TimeRangeEndpoint(end))
if endpoints:
start, end = endpoints
return (TimeRangeEndpoint(start), TimeRangeEndpoint(end))

return (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE)
41 changes: 40 additions & 1 deletion tests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import uuid
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from unittest.mock import patch
from unittest.mock import Mock, patch

import numpy
from flask import Flask
Expand Down Expand Up @@ -47,10 +47,12 @@
parse_past_timedelta,
setup_cache,
split,
TimeRangeEndpoint,
validate_json,
zlib_compress,
zlib_decompress,
)
from superset.views.utils import get_time_range_endpoints


def mock_parse_human_datetime(s):
Expand Down Expand Up @@ -881,3 +883,40 @@ def test_get_or_create_db(self):
def test_get_or_create_db_invalid_uri(self):
with self.assertRaises(ArgumentError):
get_or_create_db("test_db", "yoursql:superset.db/()")

def test_get_time_range_endpoints(self):
self.assertEqual(
get_time_range_endpoints(form_data={}, slc=None),
(TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
)

self.assertEqual(
get_time_range_endpoints(
form_data={"time_range_endpoints": ["inclusive", "inclusive"]}, slc=None
),
(TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.INCLUSIVE),
)

self.assertEqual(
get_time_range_endpoints(form_data={"datasource": "1_druid"}, slc=None),
(TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
)

slc = Mock()
slc.datasource.database.get_extra.return_value = {}

self.assertEqual(
get_time_range_endpoints(form_data={"datasource": "1__table"}, slc=slc),
(TimeRangeEndpoint.UNKNOWN, TimeRangeEndpoint.INCLUSIVE),
)

slc.datasource.database.get_extra.return_value = {
"time_range_endpoints": ["inclusive", "inclusive"]
}

self.assertEqual(
get_time_range_endpoints(form_data={"datasource": "1__table"}, slc=slc),
(TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.INCLUSIVE),
)

self.assertIsNone(get_time_range_endpoints(form_data={}, slc=slc))

0 comments on commit 29507ba

Please sign in to comment.