feat: samples with filter

zhaoyongjie committed Jul 14, 2022
1 parent 6b0bb80 commit b7edacf
Showing 6 changed files with 102 additions and 25 deletions.
4 changes: 2 additions & 2 deletions superset-frontend/src/components/Chart/chartAction.js
@@ -598,10 +598,10 @@ export function refreshChart(chartKey, force, dashboardId) {
   };
 }
 
-export const getDatasetSamples = async (datasetId, force) => {
+export const getDatasetSamples = async (datasetId, force, jsonPayload) => {
   const endpoint = `/api/v1/dataset/${datasetId}/samples?force=${force}`;
   try {
-    const response = await SupersetClient.get({ endpoint });
+    const response = await SupersetClient.post({ endpoint, jsonPayload });
     return response.json.result;
   } catch (err) {
     const clientError = await getClientErrorObject(err);
@@ -29,15 +29,15 @@ import { SamplesPane } from '../components';
 import { createSamplesPaneProps } from './fixture';
 
 describe('SamplesPane', () => {
-  fetchMock.get('end:/api/v1/dataset/34/samples?force=false', {
+  fetchMock.post('end:/api/v1/dataset/34/samples?force=false', {
     result: {
       data: [],
       colnames: [],
       coltypes: [],
     },
   });
 
-  fetchMock.get('end:/api/v1/dataset/35/samples?force=true', {
+  fetchMock.post('end:/api/v1/dataset/35/samples?force=true', {
     result: {
       data: [
         { __timestamp: 1230768000000, genre: 'Action' },
@@ -48,7 +48,7 @@ describe('SamplesPane', () => {
     },
   });
 
-  fetchMock.get('end:/api/v1/dataset/36/samples?force=false', 400);
+  fetchMock.post('end:/api/v1/dataset/36/samples?force=false', 400);
 
   const setForceQuery = jest.spyOn(exploreActions, 'setForceQuery');
22 changes: 18 additions & 4 deletions superset/datasets/api.py
@@ -60,9 +60,11 @@
     DatasetPostSchema,
     DatasetPutSchema,
     DatasetRelatedObjectsResponse,
+    DatasetSamplesQuerySchema,
     get_delete_ids_schema,
     get_export_ids_schema,
 )
+from superset.exceptions import QueryClauseValidationException
 from superset.utils.core import json_int_dttm_ser, parse_boolean_string
 from superset.views.base import DatasourceFilter, generate_download_headers
 from superset.views.base_api import (
@@ -212,7 +214,10 @@ class DatasetRestApi(BaseSupersetModelRestApi):
     apispec_parameter_schemas = {
         "get_export_ids_schema": get_export_ids_schema,
     }
-    openapi_spec_component_schemas = (DatasetRelatedObjectsResponse,)
+    openapi_spec_component_schemas = (
+        DatasetRelatedObjectsResponse,
+        DatasetSamplesQuerySchema,
+    )
 
     @expose("/", methods=["POST"])
     @protect()
@@ -764,7 +769,7 @@ def import_(self) -> Response:
         command.run()
         return self.response(200, message="OK")
 
-    @expose("/<pk>/samples")
+    @expose("/<pk>/samples", methods=["POST"])
     @protect()
     @safe
     @statsd_metrics
@@ -775,7 +780,7 @@ def import_(self) -> Response:
     def samples(self, pk: int) -> Response:
         """get samples from a Dataset
         ---
-        get:
+        post:
           description: >-
             get samples from a Dataset
           parameters:
@@ -787,6 +792,13 @@ def samples(self, pk: int) -> Response:
             schema:
               type: boolean
             name: force
+          requestBody:
+            description: Filter Schema
+            required: false
+            content:
+              application/json:
+                schema:
+                  $ref: '#/components/schemas/DatasetSamplesQuerySchema'
           responses:
             200:
               description: Dataset samples
@@ -810,7 +822,7 @@ def samples(self, pk: int) -> Response:
         """
         try:
             force = parse_boolean_string(request.args.get("force"))
-            rv = SamplesDatasetCommand(pk, force).run()
+            rv = SamplesDatasetCommand(pk, force, payload=request.json).run()
             response_data = simplejson.dumps(
                 {"result": rv},
                 default=json_int_dttm_ser,
@@ -825,3 +837,5 @@ def samples(self, pk: int) -> Response:
             return self.response_403()
         except DatasetSamplesFailedError as ex:
             return self.response_400(message=str(ex))
+        except QueryClauseValidationException as ex:
+            return self.response_400(message=str(ex))
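Taken together, the api.py changes turn /api/v1/dataset/<pk>/samples into a POST endpoint whose optional JSON body is validated against DatasetSamplesQuerySchema. A minimal sketch of the new contract from an external client, assuming a local Superset at http://localhost:8088, a bearer token obtained elsewhere, and dataset id 34 (all illustrative, not part of this commit):

# Minimal sketch of the new POST contract. The base URL, token, dataset id,
# and filter values are assumptions for illustration only.
import requests

BASE = "http://localhost:8088"   # assumed local instance
access_token = "..."             # e.g. from /api/v1/security/login; hypothetical
headers = {"Authorization": f"Bearer {access_token}"}
# Depending on config, a CSRF token header may also be required for POSTs.

# force=false -> the result may be served from the samples cache
resp = requests.post(
    f"{BASE}/api/v1/dataset/34/samples?force=false",
    headers=headers,
    # Body matches DatasetSamplesQuerySchema; omit it (or send {}) for
    # the old unfiltered behavior.
    json={"filters": [{"col": "genre", "op": "==", "val": "Action"}]},
)
resp.raise_for_status()
result = resp.json()["result"]  # data, colnames, coltypes, cache_key, ...

Per the integration tests below, an unknown body key or an invalid filter operator now returns HTTP 400 rather than being silently ignored.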
32 changes: 23 additions & 9 deletions superset/datasets/commands/samples.py
@@ -14,8 +14,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import logging
-from typing import Any, Dict, Optional
+from typing import Any, cast, Dict, Optional
+
+from marshmallow import ValidationError
 
 from superset import security_manager
 from superset.commands.base import BaseCommand
@@ -30,29 +31,37 @@
     DatasetSamplesFailedError,
 )
 from superset.datasets.dao import DatasetDAO
-from superset.exceptions import SupersetSecurityException
+from superset.datasets.schemas import DatasetSamplesQuerySchema
+from superset.exceptions import (
+    QueryClauseValidationException,
+    SupersetSecurityException,
+)
 from superset.utils.core import QueryStatus
 
-logger = logging.getLogger(__name__)
-
 
 class SamplesDatasetCommand(BaseCommand):
-    def __init__(self, model_id: int, force: bool):
+    def __init__(
+        self,
+        model_id: int,
+        force: bool,
+        *,
+        payload: Optional[DatasetSamplesQuerySchema] = None,
+    ):
         self._model_id = model_id
         self._force = force
         self._model: Optional[SqlaTable] = None
+        self._payload = payload
 
     def run(self) -> Dict[str, Any]:
         self.validate()
-        if not self._model:
-            raise DatasetNotFoundError()
+        self._model = cast(SqlaTable, self._model)
 
         qc_instance = QueryContextFactory().create(
             datasource={
                 "type": self._model.type,
                 "id": self._model.id,
             },
-            queries=[{}],
+            queries=[self._payload] if self._payload else [{}],
             result_type=ChartDataResultType.SAMPLES,
             force=self._force,
         )
@@ -78,3 +87,8 @@ def validate(self) -> None:
             security_manager.raise_for_ownership(self._model)
         except SupersetSecurityException as ex:
             raise DatasetForbiddenError() from ex
+
+        try:
+            self._payload = DatasetSamplesQuerySchema().load(self._payload)
+        except ValidationError as ex:
+            raise QueryClauseValidationException() from ex
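Inside the command, validate() now normalizes the raw request body through the schema before run() splices it into the query context as the sole query object. A rough sketch of that flow, assuming a Flask app context, an authorized user, and an existing dataset (the id and filter values are illustrative):

# Rough sketch only; assumes an app context, an authorized user, and a real
# dataset id 34. The id and filter values are illustrative.
from superset.datasets.commands.samples import SamplesDatasetCommand

payload = {"filters": [{"col": "genre", "op": "==", "val": "Action"}]}

# validate() checks existence/ownership, then loads `payload` through
# DatasetSamplesQuerySchema, raising QueryClauseValidationException on bad
# input; run() then issues a SAMPLES-type query with the filters applied.
result = SamplesDatasetCommand(34, False, payload=payload).run()
# `result` carries data/colnames/coltypes/cache_key, mirroring the HTTP response.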
12 changes: 12 additions & 0 deletions superset/datasets/schemas.py
@@ -23,6 +23,7 @@
 from marshmallow.validate import Length
 from marshmallow_sqlalchemy import SQLAlchemyAutoSchema
 
+from superset.charts.schemas import ChartDataFilterSchema
 from superset.datasets.models import Dataset
 
 get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}}
@@ -231,3 +232,14 @@ class Meta:  # pylint: disable=too-few-public-methods
         model = Dataset
         load_instance = True
         include_relationships = True
+
+
+class DatasetSamplesQuerySchema(Schema):
+    filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
+
+    @pre_load
+    # pylint: disable=no-self-use, unused-argument
+    def handle_none(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+        if data is None:
+            return {}
+        return data
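The schema is deliberately forgiving about an absent body but strict about its shape. A quick sketch of its load() behavior, assuming ChartDataFilterSchema's usual operator validation (the column and operator values are illustrative):

# Quick sketch of DatasetSamplesQuerySchema behavior; column/operator values
# are illustrative, and ChartDataFilterSchema is assumed to validate `op`
# against Superset's known filter operators.
from marshmallow import ValidationError

from superset.datasets.schemas import DatasetSamplesQuerySchema

schema = DatasetSamplesQuerySchema()

assert schema.load(None) == {}  # pre_load maps a missing body to an empty query
schema.load({"filters": [{"col": "foo", "op": "==", "val": "foo2"}]})  # accepted

try:
    schema.load({"filters": [{"col": "foo", "op": "INVALID", "val": "foo2"}]})
except ValidationError:
    pass  # the command re-raises this as QueryClauseValidationException (HTTP 400)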
51 changes: 44 additions & 7 deletions tests/integration_tests/datasets/api_tests.py
@@ -1885,9 +1885,9 @@ def test_get_dataset_samples(self):
 
         # 1. should cache data
         # feeds data
-        self.client.get(uri)
+        self.client.post(uri)
         # get from cache
-        rv = self.client.get(uri)
+        rv = self.client.post(uri)
         rv_data = json.loads(rv.data)
         assert rv.status_code == 200
         assert "result" in rv_data
@@ -1898,9 +1898,9 @@ def test_get_dataset_samples(self):
         # 2. should through cache
         uri2 = f"api/v1/dataset/{dataset.id}/samples?force=true"
         # feeds data
-        self.client.get(uri2)
+        self.client.post(uri2)
         # force query
-        rv2 = self.client.get(uri2)
+        rv2 = self.client.post(uri2)
         rv_data2 = json.loads(rv2.data)
         assert rv_data2["result"]["cached_dttm"] is None
         cache_key2 = rv_data2["result"]["cache_key"]
@@ -1930,7 +1930,7 @@ def test_get_dataset_samples_with_failed_cc(self):
         )
         uri = f"api/v1/dataset/{dataset.id}/samples"
         dataset.columns.append(failed_column)
-        rv = self.client.get(uri)
+        rv = self.client.post(uri)
         assert rv.status_code == 400
         rv_data = json.loads(rv.data)
         assert "message" in rv_data
@@ -1949,16 +1949,53 @@ def test_get_dataset_samples_on_virtual_dataset(self):
 
         self.login(username="admin")
         uri = f"api/v1/dataset/{virtual_dataset.id}/samples"
-        rv = self.client.get(uri)
+        rv = self.client.post(uri)
         assert rv.status_code == 200
         rv_data = json.loads(rv.data)
         cache_key = rv_data["result"]["cache_key"]
         assert QueryCacheManager.has(cache_key, region=CacheRegion.DATA)
 
         # remove original column in dataset
         virtual_dataset.sql = "SELECT 'foo' as foo"
-        rv = self.client.get(uri)
+        rv = self.client.post(uri)
         assert rv.status_code == 400
 
         db.session.delete(virtual_dataset)
         db.session.commit()
 
+    def test_get_dataset_samples_with_filters(self):
+        virtual_dataset = SqlaTable(
+            table_name="virtual_dataset",
+            sql=("SELECT 'foo' as foo, 'bar' as bar UNION ALL SELECT 'foo2', 'bar2'"),
+            database=get_example_database(),
+        )
+        TableColumn(column_name="foo", type="VARCHAR(255)", table=virtual_dataset)
+        TableColumn(column_name="bar", type="VARCHAR(255)", table=virtual_dataset)
+        SqlMetric(metric_name="count", expression="count(*)", table=virtual_dataset)
+
+        self.login(username="admin")
+        uri = f"api/v1/dataset/{virtual_dataset.id}/samples"
+        rv = self.client.post(uri, json=None)
+        assert rv.status_code == 200
+
+        rv = self.client.post(uri, json={})
+        assert rv.status_code == 200
+
+        rv = self.client.post(uri, json={"foo": "bar"})
+        assert rv.status_code == 400
+
+        rv = self.client.post(
+            uri, json={"filters": [{"col": "foo", "op": "INVALID", "val": "foo2"}]}
+        )
+        assert rv.status_code == 400
+
+        rv = self.client.post(
+            uri, json={"filters": [{"col": "foo", "op": "==", "val": "foo2"}]}
+        )
+        assert rv.status_code == 200
+        rv_data = json.loads(rv.data)
+        assert rv_data["result"]["colnames"] == ["foo", "bar"]
+        assert rv_data["result"]["rowcount"] == 1
+
+        db.session.delete(virtual_dataset)
+        db.session.commit()
