Skip to content

Commit

Permalink
feat: Add new result formats and types to chart data API (apache#9841)
Browse files Browse the repository at this point in the history
* feat: Add new result formats and types to chart data API

* lint

* Linting

* Add language to query payload

* Fix tests

* simplify tests
  • Loading branch information
villebro authored and auxten committed Nov 20, 2020
1 parent aa0cef7 commit c371c96
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 23 deletions.
Expand Up @@ -27,7 +27,7 @@ import ExploreChartPanel from './ExploreChartPanel';
import ControlPanelsContainer from './ControlPanelsContainer';
import SaveModal from './SaveModal';
import QueryAndSaveBtns from './QueryAndSaveBtns';
import { getExploreUrl, getExploreLongUrl } from '../exploreUtils';
import { getExploreLongUrl } from '../exploreUtils';
import { areObjectsEqual } from '../../reduxUtils';
import { getFormDataFromControls } from '../controlUtils';
import { chartPropShape } from '../../dashboard/util/propShapes';
Expand Down
8 changes: 8 additions & 0 deletions superset/charts/schemas.py
Expand Up @@ -711,6 +711,14 @@ class ChartDataQueryContextSchema(Schema):
description="Should the queries be forced to load from the source. "
"Default: `false`",
)
result_type = fields.String(
description="Type of results to return",
validate=validate.OneOf(choices=("query", "results", "samples")),
)
result_format = fields.String(
description="Format of result payload",
validate=validate.OneOf(choices=("json", "csv")),
)

# pylint: disable=no-self-use
@post_load
Expand Down
37 changes: 32 additions & 5 deletions superset/common/query_context.py
Expand Up @@ -14,10 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import copy
import logging
import pickle as pkl
from datetime import datetime, timedelta
from typing import Any, ClassVar, Dict, List, Optional
from typing import Any, ClassVar, Dict, List, Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -49,22 +50,28 @@ class QueryContext:
queries: List[QueryObject]
force: bool
custom_cache_timeout: Optional[int]
response_type: utils.ChartDataResponseType
response_format: utils.ChartDataResponseFormat

# TODO: Type datasource and query_object dictionary with TypedDict when it becomes
# a vanilla python type https://github.com/python/mypy/issues/5288
def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
datasource: Dict[str, Any],
queries: List[Dict[str, Any]],
force: bool = False,
custom_cache_timeout: Optional[int] = None,
response_format: Optional[utils.ChartDataResponseFormat] = None,
response_type: Optional[utils.ChartDataResponseType] = None,
) -> None:
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)
self.queries = [QueryObject(**query_obj) for query_obj in queries]
self.force = force
self.custom_cache_timeout = custom_cache_timeout
self.response_format = response_format or utils.ChartDataResponseFormat.JSON
self.response_type = response_type or utils.ChartDataResponseType.RESULTS

def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
"""Returns a pandas dataframe based on the query object"""
Expand Down Expand Up @@ -124,12 +131,32 @@ def df_metrics_to_num( # pylint: disable=no-self-use
if dtype.type == np.object_ and col in query_object.metrics:
df[col] = pd.to_numeric(df[col], errors="coerce")

@staticmethod
def get_data(df: pd.DataFrame,) -> List[Dict]: # pylint: disable=no-self-use
def get_data(
self, df: pd.DataFrame,
) -> Union[str, List[Dict[str, Any]]]: # pylint: disable=no-self-use
if self.response_format == utils.ChartDataResponseFormat.CSV:
include_index = not isinstance(df.index, pd.RangeIndex)
result = df.to_csv(index=include_index, **config["CSV_EXPORT"])
return result or ""

return df.to_dict(orient="records")

def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]:
"""Returns a payload of metadata and data"""
if self.response_type == utils.ChartDataResponseType.QUERY:
return {
"query": self.datasource.get_query_str(query_obj.to_dict()),
"language": self.datasource.query_language,
}
if self.response_type == utils.ChartDataResponseType.SAMPLES:
row_limit = query_obj.row_limit or 1000
query_obj = copy.copy(query_obj)
query_obj.groupby = []
query_obj.metrics = []
query_obj.post_processing = []
query_obj.row_limit = row_limit
query_obj.columns = [o.column_name for o in self.datasource.columns]

payload = self.get_df_payload(query_obj)
df = payload["df"]
status = payload["status"]
Expand All @@ -142,7 +169,7 @@ def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]:
return payload

def get_payload(self) -> List[Dict[str, Any]]:
"""Get all the payloads from the arrays"""
"""Get all the payloads from the QueryObjects"""
return [self.get_single_payload(query_object) for query_object in self.queries]

@property
Expand Down
19 changes: 19 additions & 0 deletions superset/utils/core.py
Expand Up @@ -1367,3 +1367,22 @@ class FilterOperator(str, Enum):
IN = "IN"
NOT_IN = "NOT IN"
REGEX = "REGEX"


class ChartDataResponseType(str, Enum):
    """
    Kind of result a chart data request can ask for.

    Members also subclass ``str``, so each one compares equal to its raw
    string value (e.g. ``ChartDataResponseType.QUERY == "query"``).
    """

    # Declaration order is the iteration order of the enum; keep it stable.
    QUERY = "query"
    RESULTS = "results"
    SAMPLES = "samples"


class ChartDataResponseFormat(str, Enum):
    """
    Serialization format of a chart data response payload.

    Members also subclass ``str``, so each one compares equal to its raw
    string value (e.g. ``ChartDataResponseFormat.CSV == "csv"``).
    """

    # Declaration order is the iteration order of the enum; keep it stable.
    CSV = "csv"
    JSON = "json"
31 changes: 15 additions & 16 deletions superset/views/core.py
Expand Up @@ -636,24 +636,22 @@ def get_raw_results(self, viz_obj):
def get_samples(self, viz_obj):
return self.json_response({"data": viz_obj.get_samples()})

def generate_json(
self, viz_obj, csv=False, query=False, results=False, samples=False
):
if csv:
def generate_json(self, viz_obj, response_type: Optional[str] = None) -> Response:
if response_type == utils.ChartDataResponseFormat.CSV:
return CsvResponse(
viz_obj.get_csv(),
status=200,
headers=generate_download_headers("csv"),
mimetype="application/csv",
)

if query:
if response_type == utils.ChartDataResponseType.QUERY:
return self.get_query_string_response(viz_obj)

if results:
if response_type == utils.ChartDataResponseType.RESULTS:
return self.get_raw_results(viz_obj)

if samples:
if response_type == utils.ChartDataResponseType.SAMPLES:
return self.get_samples(viz_obj)

payload = viz_obj.get_payload()
Expand Down Expand Up @@ -715,11 +713,14 @@ def explore_json(self, datasource_type=None, datasource_id=None):
payloads based on the request args in the first block
TODO: break into one endpoint for each return shape"""
csv = request.args.get("csv") == "true"
query = request.args.get("query") == "true"
results = request.args.get("results") == "true"
samples = request.args.get("samples") == "true"
force = request.args.get("force") == "true"
response_type = utils.ChartDataResponseFormat.JSON.value
responses = [resp_format for resp_format in utils.ChartDataResponseFormat]
responses.extend([resp_type for resp_type in utils.ChartDataResponseType])
for response_option in responses:
if request.args.get(response_option) == "true":
response_type = response_option
break

form_data = get_form_data()[0]

try:
Expand All @@ -731,12 +732,10 @@ def explore_json(self, datasource_type=None, datasource_id=None):
datasource_type=datasource_type,
datasource_id=datasource_id,
form_data=form_data,
force=force,
force=request.args.get("force") == "true",
)

return self.generate_json(
viz_obj, csv=csv, query=query, results=results, samples=samples
)
return self.generate_json(viz_obj, response_type)
except SupersetException as ex:
return json_error_response(utils.error_msg_from_exception(ex))

Expand Down
58 changes: 57 additions & 1 deletion tests/query_context_tests.py
Expand Up @@ -19,7 +19,11 @@
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.query_context import QueryContext
from superset.connectors.connector_registry import ConnectorRegistry
from superset.utils.core import TimeRangeEndpoint
from superset.utils.core import (
ChartDataResponseFormat,
ChartDataResponseType,
TimeRangeEndpoint,
)
from tests.base_tests import SupersetTestCase
from tests.fixtures.query_context import get_query_context

Expand Down Expand Up @@ -131,3 +135,55 @@ def test_convert_deprecated_fields(self):
query_object = query_context.queries[0]
self.assertEqual(query_object.granularity, "timecol")
self.assertIn("having_druid", query_object.extras)

def test_csv_response_format(self):
    """
    Ensure that the CSV response format returns the data as a CSV string
    """
    self.login(username="admin")
    table_name = "birth_names"
    table = self.get_table_by_name(table_name)
    payload = get_query_context(table.name, table.id, table.type)
    payload["response_format"] = ChartDataResponseFormat.CSV.value
    payload["queries"][0]["row_limit"] = 10
    query_context = QueryContext(**payload)
    responses = query_context.get_payload()
    # One payload is produced per query object in the context.
    self.assertEqual(len(responses), 1)
    data = responses[0]["data"]
    self.assertIn("name,sum__num\n", data)
    # Presumably: 1 header line + 10 rows (row_limit) + the empty string
    # that follows the final newline when splitting.
    self.assertEqual(len(data.split("\n")), 12)

def test_samples_response_type(self):
    """
    Ensure that the samples response type returns raw rows without metrics
    """
    self.login(username="admin")
    table_name = "birth_names"
    table = self.get_table_by_name(table_name)
    payload = get_query_context(table.name, table.id, table.type)
    payload["response_type"] = ChartDataResponseType.SAMPLES.value
    payload["queries"][0]["row_limit"] = 5
    query_context = QueryContext(**payload)
    responses = query_context.get_payload()
    # One payload is produced per query object in the context.
    self.assertEqual(len(responses), 1)
    data = responses[0]["data"]
    self.assertIsInstance(data, list)
    # The requested row_limit must be honoured for samples.
    self.assertEqual(len(data), 5)
    # Metric columns must not appear in a samples response.
    self.assertNotIn("sum__num", data[0])

def test_query_response_type(self):
    """
    Ensure that the query response type returns only the query and its language
    """
    self.login(username="admin")
    table_name = "birth_names"
    table = self.get_table_by_name(table_name)
    payload = get_query_context(table.name, table.id, table.type)
    payload["response_type"] = ChartDataResponseType.QUERY.value
    query_context = QueryContext(**payload)
    responses = query_context.get_payload()
    # One payload is produced per query object in the context.
    self.assertEqual(len(responses), 1)
    response = responses[0]
    # Exactly two keys are expected: "query" and "language".
    self.assertEqual(len(response), 2)
    self.assertEqual(response["language"], "sql")
    self.assertIn("SELECT", response["query"])

0 comments on commit c371c96

Please sign in to comment.