Skip to content

Commit

Permalink
feat: add support for query offset (#10010)
Browse files Browse the repository at this point in the history
* feat: add support for query offset

* Address comments and add new tests
  • Loading branch information
villebro committed Jun 9, 2020
1 parent 2a3305e commit 315518d
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 23 deletions.
12 changes: 11 additions & 1 deletion superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
# under the License.
from typing import Any, Dict, Union

from flask_babel import gettext as _
from marshmallow import fields, post_load, Schema, validate, ValidationError
from marshmallow.validate import Length
from marshmallow.validate import Length, Range

from superset.common.query_context import QueryContext
from superset.exceptions import SupersetException
Expand Down Expand Up @@ -663,6 +664,15 @@ class ChartDataQueryObjectSchema(Schema):
)
row_limit = fields.Integer(
description='Maximum row count. Default: `config["ROW_LIMIT"]`',
validate=[
Range(min=1, error=_("`row_limit` must be greater than or equal to 1"))
],
)
row_offset = fields.Integer(
description="Number of rows to skip. Default: `0`",
validate=[
Range(min=0, error=_("`row_offset` must be greater than or equal to 0"))
],
)
order_desc = fields.Boolean(
description="Reverse order. Default: `false`", required=False
Expand Down
4 changes: 2 additions & 2 deletions superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@
import pandas as pd

from superset import app, cache, db, security_manager
from superset.common.query_object import QueryObject
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.stats_logger import BaseStatsLogger
from superset.utils import core as utils
from superset.utils.core import DTTM_ALIAS

from .query_object import QueryObject

config = app.config
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -156,6 +155,7 @@ def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]:
query_obj.metrics = []
query_obj.post_processing = []
query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
query_obj.row_offset = 0
query_obj.columns = [o.column_name for o in self.datasource.columns]
payload = self.get_df_payload(query_obj)
df = payload["df"]
Expand Down
15 changes: 10 additions & 5 deletions superset/common/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from superset.utils import core as utils, pandas_postprocessing
from superset.views.utils import get_time_range_endpoints

config = app.config
logger = logging.getLogger(__name__)

# TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type
Expand Down Expand Up @@ -66,6 +67,7 @@ class QueryObject:
groupby: List[str]
metrics: List[Union[Dict[str, Any], str]]
row_limit: int
row_offset: int
filter: List[Dict[str, Any]]
timeseries_limit: int
timeseries_limit_metric: Optional[Metric]
Expand All @@ -85,7 +87,8 @@ def __init__(
time_shift: Optional[str] = None,
is_timeseries: bool = False,
timeseries_limit: int = 0,
row_limit: int = app.config["ROW_LIMIT"],
row_limit: Optional[int] = None,
row_offset: Optional[int] = None,
timeseries_limit_metric: Optional[Metric] = None,
order_desc: bool = True,
extras: Optional[Dict[str, Any]] = None,
Expand All @@ -100,10 +103,10 @@ def __init__(
self.granularity = granularity
self.from_dttm, self.to_dttm = utils.get_since_until(
relative_start=extras.get(
"relative_start", app.config["DEFAULT_RELATIVE_START_TIME"]
"relative_start", config["DEFAULT_RELATIVE_START_TIME"]
),
relative_end=extras.get(
"relative_end", app.config["DEFAULT_RELATIVE_END_TIME"]
"relative_end", config["DEFAULT_RELATIVE_END_TIME"]
),
time_range=time_range,
time_shift=time_shift,
Expand All @@ -123,14 +126,15 @@ def __init__(
for metric in metrics
]

self.row_limit = row_limit
self.row_limit = row_limit or config["ROW_LIMIT"]
self.row_offset = row_offset or 0
self.filter = filters or []
self.timeseries_limit = timeseries_limit
self.timeseries_limit_metric = timeseries_limit_metric
self.order_desc = order_desc
self.extras = extras

if app.config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras:
if config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras:
self.extras["time_range_endpoints"] = get_time_range_endpoints(form_data={})

self.columns = columns or []
Expand Down Expand Up @@ -184,6 +188,7 @@ def to_dict(self) -> Dict[str, Any]:
"is_timeseries": self.is_timeseries,
"metrics": self.metrics,
"row_limit": self.row_limit,
"row_offset": self.row_offset,
"filter": self.filter,
"timeseries_limit": self.timeseries_limit,
"timeseries_limit_metric": self.timeseries_limit_metric,
Expand Down
3 changes: 3 additions & 0 deletions superset/connectors/druid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,7 @@ def run_query( # druid
timeseries_limit: Optional[int] = None,
timeseries_limit_metric: Optional[Metric] = None,
row_limit: Optional[int] = None,
row_offset: Optional[int] = None,
inner_from_dttm: Optional[datetime] = None,
inner_to_dttm: Optional[datetime] = None,
orderby: Optional[Any] = None,
Expand All @@ -1192,6 +1193,8 @@ def run_query( # druid
# TODO refactor into using a TBD Query object
client = client or self.cluster.get_pydruid_client()
row_limit = row_limit or conf.get("ROW_LIMIT")
if row_offset:
raise SupersetException("Offset not implemented for Druid connector")

if not is_timeseries:
granularity = "all"
Expand Down
4 changes: 4 additions & 0 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ def get_sqla_query( # sqla
timeseries_limit: int = 15,
timeseries_limit_metric: Optional[Metric] = None,
row_limit: Optional[int] = None,
row_offset: Optional[int] = None,
inner_from_dttm: Optional[datetime] = None,
inner_to_dttm: Optional[datetime] = None,
orderby: Optional[List[Tuple[ColumnElement, bool]]] = None,
Expand All @@ -753,6 +754,7 @@ def get_sqla_query( # sqla
"groupby": groupby,
"metrics": metrics,
"row_limit": row_limit,
"row_offset": row_offset,
"to_dttm": to_dttm,
"filter": filter,
"columns": {col.column_name: col for col in self.columns},
Expand Down Expand Up @@ -967,6 +969,8 @@ def get_sqla_query( # sqla

if row_limit:
qry = qry.limit(row_limit)
if row_offset:
qry = qry.offset(row_offset)

if (
is_timeseries
Expand Down
89 changes: 74 additions & 15 deletions tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
from typing import List, Optional
from unittest import mock

import prison
from sqlalchemy.sql import func

import tests.test_app
from tests.test_app import app
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import db, security_manager
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils import core as utils
from tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.base_tests import SupersetTestCase
from tests.fixtures.query_context import get_query_context

CHART_DATA_URI = "api/v1/chart/data"


class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
resource_name = "chart"
Expand Down Expand Up @@ -634,32 +639,88 @@ def test_get_charts_no_data_access(self):
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 0)

def test_chart_data(self):
def test_chart_data_simple(self):
"""
Query API: Test chart data query
Chart data API: Test chart data query
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
uri = "api/v1/chart/data"
rv = self.post_assert_metric(uri, payload, "data")
request_payload = get_query_context(table.name, table.id, table.type)
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["result"][0]["rowcount"], 100)

def test_chart_data_limit_offset(self):
    """
    Chart data API: Test chart data query with limit and offset
    """
    self.login(username="admin")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    query = payload["queries"][0]
    query["row_limit"] = 5
    query["row_offset"] = 0
    query["orderby"] = [["name", True]]

    def fetch_first_result():
        # POST the payload and pull the first (only) query result out.
        rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
        return json.loads(rv.data.decode("utf-8"))["result"][0]

    first_page = fetch_first_result()
    self.assertEqual(first_page["rowcount"], 5)

    # ensure that offset works properly: the row at index `offset` of the
    # unshifted page must become row 0 once the same offset is applied.
    offset = 2
    expected_name = first_page["data"][offset]["name"]
    query["row_offset"] = offset
    shifted_page = fetch_first_result()
    self.assertEqual(shifted_page["rowcount"], 5)
    self.assertEqual(shifted_page["data"][0]["name"], expected_name)

@mock.patch(
    "superset.common.query_object.config", {**app.config, "ROW_LIMIT": 7},
)
def test_chart_data_default_row_limit(self):
    """
    Chart data API: Ensure row count doesn't exceed default limit
    """
    self.login(username="admin")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    # Drop the explicit limit so the (patched) ROW_LIMIT default applies.
    payload["queries"][0].pop("row_limit")
    response = self.post_assert_metric(CHART_DATA_URI, payload, "data")
    body = json.loads(response.data.decode("utf-8"))
    self.assertEqual(body["result"][0]["rowcount"], 7)

@mock.patch(
    "superset.common.query_context.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
)
def test_chart_data_default_sample_limit(self):
    """
    Chart data API: Ensure sample response row count doesn't exceed default limit
    """
    self.login(username="admin")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    payload["result_type"] = utils.ChartDataResultType.SAMPLES
    # Ask for more rows than the (patched) SAMPLES_ROW_LIMIT allows.
    payload["queries"][0]["row_limit"] = 10
    response = self.post_assert_metric(CHART_DATA_URI, payload, "data")
    body = json.loads(response.data.decode("utf-8"))
    self.assertEqual(body["result"][0]["rowcount"], 5)

def test_chart_data_with_invalid_datasource(self):
"""Query API: Test chart data query with invalid schema
"""Chart data API: Test chart data query with invalid schema
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
payload["datasource"] = "abc"
uri = "api/v1/chart/data"
rv = self.post_assert_metric(uri, payload, "data")
rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
self.assertEqual(rv.status_code, 400)

def test_chart_data_with_invalid_enum_value(self):
"""Query API: Test chart data query with invalid enum value
"""Chart data API: Test chart data query with invalid enum value
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
Expand All @@ -668,19 +729,17 @@ def test_chart_data_with_invalid_enum_value(self):
"abc",
"EXCLUSIVE",
]
uri = "api/v1/chart/data"
rv = self.client.post(uri, json=payload)
rv = self.client.post(CHART_DATA_URI, json=payload)
self.assertEqual(rv.status_code, 400)

def test_query_exec_not_allowed(self):
"""
Query API: Test chart data query not allowed
Chart data API: Test chart data query not allowed
"""
self.login(username="gamma")
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
uri = "api/v1/chart/data"
rv = self.post_assert_metric(uri, payload, "data")
rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
self.assertEqual(rv.status_code, 401)

def test_datasources(self):
Expand Down
61 changes: 61 additions & 0 deletions tests/charts/schema_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for Superset"""
from typing import Any, Dict, Tuple

from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.query_context import QueryContext
from tests.base_tests import SupersetTestCase
from tests.fixtures.query_context import get_query_context
from tests.test_app import app


def load_query_context(payload: Dict[str, Any]) -> Tuple[QueryContext, Dict[str, Any]]:
    """Deserialize *payload* through the chart data query context schema."""
    schema = ChartDataQueryContextSchema()
    return schema.load(payload)


class SchemaTestCase(SupersetTestCase):
    """Tests for chart data request schema validation."""

    def test_query_context_limit_and_offset(self):
        """row_limit/row_offset: defaults, explicit values, and range validation."""
        self.login(username="admin")
        table_name = "birth_names"
        table = self.get_table_by_name(table_name)
        payload = get_query_context(table.name, table.id, table.type)

        # Use defaults: omitted fields fall back to ROW_LIMIT and offset 0
        payload["queries"][0].pop("row_limit", None)
        payload["queries"][0].pop("row_offset", None)
        query_context, errors = load_query_context(payload)
        self.assertEqual(errors, {})
        query_object = query_context.queries[0]
        self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"])
        self.assertEqual(query_object.row_offset, 0)

        # Valid limit and offset are passed through unchanged
        payload["queries"][0]["row_limit"] = 100
        payload["queries"][0]["row_offset"] = 200
        query_context, errors = load_query_context(payload)
        self.assertEqual(errors, {})
        query_object = query_context.queries[0]
        self.assertEqual(query_object.row_limit, 100)
        self.assertEqual(query_object.row_offset, 200)

        # too low limit and offset (limit < 1, offset < 0) are rejected
        payload["queries"][0]["row_limit"] = 0
        payload["queries"][0]["row_offset"] = -1
        query_context, errors = load_query_context(payload)
        self.assertIn("row_limit", errors["queries"][0])
        self.assertIn("row_offset", errors["queries"][0])

0 comments on commit 315518d

Please sign in to comment.