feat: supports multiple filters in samples endpoint (#21008)
zhaoyongjie committed Aug 8, 2022
1 parent e214e1a commit 802b69f
Showing 6 changed files with 154 additions and 60 deletions.
1 change: 1 addition & 0 deletions superset/common/chart_data.py
@@ -38,3 +38,4 @@ class ChartDataResultType(str, Enum):
SAMPLES = "samples"
TIMEGRAINS = "timegrains"
POST_PROCESSED = "post_processed"
DRILL_DETAIL = "drill_detail"
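
Since ChartDataResultType mixes in str, the new member can be compared against, and parsed from, the plain string "drill_detail". A small standalone illustration (the ResultType Enum below is a stand-in, not the Superset import):

from enum import Enum

class ResultType(str, Enum):  # mirrors ChartDataResultType's str mixin
    SAMPLES = "samples"
    DRILL_DETAIL = "drill_detail"

assert ResultType.DRILL_DETAIL == "drill_detail"               # str mixin: equal to its value
assert ResultType("drill_detail") is ResultType.DRILL_DETAIL   # parses back from the string
print(ResultType.DRILL_DETAIL.value)                           # drill_detail
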
22 changes: 22 additions & 0 deletions superset/common/query_actions.py
@@ -162,6 +162,27 @@ def _get_samples(
return _get_full(query_context, query_obj, force_cached)


def _get_drill_detail(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
) -> Dict[str, Any]:
# todo(yongjie): remove this function once it is determined
# whether samples should apply the time filter.
datasource = _get_datasource(query_context, query_obj)
query_obj = copy.copy(query_obj)
query_obj.is_timeseries = False
query_obj.orderby = []
query_obj.metrics = None
query_obj.post_processing = []
qry_obj_cols = []
for o in datasource.columns:
if isinstance(o, dict):
qry_obj_cols.append(o.get("column_name"))
else:
qry_obj_cols.append(o.column_name)
query_obj.columns = qry_obj_cols
return _get_full(query_context, query_obj, force_cached)


def _get_results(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
) -> Dict[str, Any]:
@@ -182,6 +203,7 @@ def _get_results(
# and post-process it later where we have the chart context, since
# post-processing is unique to each visualization type
ChartDataResultType.POST_PROCESSED: _get_full,
ChartDataResultType.DRILL_DETAIL: _get_drill_detail,
}
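
For reference, a condensed, self-contained sketch of the column-name normalization inside the new _get_drill_detail handler above; the TableColumnStub class and column_names helper are illustrative stand-ins, not Superset code:

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class TableColumnStub:
    column_name: str

def column_names(columns: List[Union[Dict[str, Any], TableColumnStub]]) -> List[str]:
    # datasource columns may be plain dicts or objects exposing a column_name attribute
    names: List[str] = []
    for col in columns:
        if isinstance(col, dict):
            names.append(col.get("column_name"))
        else:
            names.append(col.column_name)
    return names

print(column_names([{"column_name": "col1"}, TableColumnStub("col2")]))  # ['col1', 'col2']
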


13 changes: 12 additions & 1 deletion superset/views/datasource/schemas.py
@@ -20,7 +20,7 @@
from typing_extensions import TypedDict

from superset import app
from superset.charts.schemas import ChartDataFilterSchema
from superset.charts.schemas import ChartDataExtrasSchema, ChartDataFilterSchema
from superset.utils.core import DatasourceType


@@ -62,6 +62,17 @@ def normalize(

class SamplesPayloadSchema(Schema):
filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
granularity = fields.String(
allow_none=True,
)
time_range = fields.String(
allow_none=True,
)
extras = fields.Nested(
ChartDataExtrasSchema,
description="Extra parameters to add to the query.",
allow_none=True,
)

@pre_load
# pylint: disable=no-self-use, unused-argument
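
With these additions the samples payload can carry adhoc filters, a temporal column plus time range, and free-form extras. An example body the schema accepts, taken from the tests added later in this commit (shown as a Python dict):

payload = {
    "granularity": "col5",
    "time_range": "2000-01-02 : 2000-01-04",
    "filters": [{"col": "col2", "op": "==", "val": "c"}],
    "extras": {"where": "col3 = 1.2 and col4 is null"},
}
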
35 changes: 24 additions & 11 deletions superset/views/datasource/utils.py
@@ -60,17 +60,30 @@ def get_samples( # pylint: disable=too-many-arguments,too-many-locals
limit_clause = get_limit_clause(page, per_page)

# todo(yongjie): Constructing count(*) and samples in the same query_context,
# then remove query_type==SAMPLES
# constructing samples query
samples_instance = QueryContextFactory().create(
datasource={
"type": datasource.type,
"id": datasource.id,
},
queries=[{**payload, **limit_clause} if payload else limit_clause],
result_type=ChartDataResultType.SAMPLES,
force=force,
)
if payload is None:
# constructing samples query
samples_instance = QueryContextFactory().create(
datasource={
"type": datasource.type,
"id": datasource.id,
},
queries=[limit_clause],
result_type=ChartDataResultType.SAMPLES,
force=force,
)
else:
# constructing drill detail query
# when query_type == 'samples', the time filter is removed,
# so it is not applicable to a drill detail query
samples_instance = QueryContextFactory().create(
datasource={
"type": datasource.type,
"id": datasource.id,
},
queries=[{**payload, **limit_clause}],
result_type=ChartDataResultType.DRILL_DETAIL,
force=force,
)

# constructing count(*) query
count_star_metric = {
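
In short, get_samples now routes any request that carries a payload through the drill-detail path, because the SAMPLES result type strips the time filter. A minimal sketch of that routing rule (the helper name is mine; the returned strings match the ChartDataResultType values):

def pick_result_type(payload):
    if payload is None:
        # no payload: plain row sampling, the SAMPLES behaviour is fine
        return "samples"
    # a payload may carry granularity/time_range/filters/extras, and the
    # SAMPLES path drops the time filter, so route it to DRILL_DETAIL
    return "drill_detail"

print(pick_result_type(None))                          # samples
print(pick_result_type({"time_range": "No filter"}))   # drill_detail
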
11 changes: 6 additions & 5 deletions tests/integration_tests/conftest.py
@@ -314,7 +314,7 @@ def physical_dataset():
col2 VARCHAR(255),
col3 DECIMAL(4,2),
col4 VARCHAR(255),
col5 VARCHAR(255)
col5 TIMESTAMP
);
"""
)
@@ -342,11 +342,10 @@ def physical_dataset():
TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col5", type="TIMESTAMP", is_dttm=True, table=dataset)
SqlMetric(metric_name="count", expression="count(*)", table=dataset)
db.session.merge(dataset)
if example_database.backend == "sqlite":
db.session.commit()
db.session.commit()

yield dataset

@@ -355,5 +354,7 @@
DROP TABLE physical_dataset;
"""
)
db.session.delete(dataset)
dataset = db.session.query(SqlaTable).filter_by(table_name="physical_dataset").all()
for ds in dataset:
db.session.delete(ds)
db.session.commit()
132 changes: 89 additions & 43 deletions tests/integration_tests/datasource_tests.py
@@ -432,34 +432,32 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
test_client.post(uri)
# get from cache
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv.status_code == 200
assert len(rv_data["result"]["data"]) == 10
assert len(rv.json["result"]["data"]) == 10
assert QueryCacheManager.has(
rv_data["result"]["cache_key"],
rv.json["result"]["cache_key"],
region=CacheRegion.DATA,
)
assert rv_data["result"]["is_cached"]
assert rv.json["result"]["is_cached"]

# 2. should read through cache data
uri2 = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true"
# feeds data
test_client.post(uri2)
# force query
rv2 = test_client.post(uri2)
rv_data2 = json.loads(rv2.data)
assert rv2.status_code == 200
assert len(rv_data2["result"]["data"]) == 10
assert len(rv2.json["result"]["data"]) == 10
assert QueryCacheManager.has(
rv_data2["result"]["cache_key"],
rv2.json["result"]["cache_key"],
region=CacheRegion.DATA,
)
assert not rv_data2["result"]["is_cached"]
assert not rv2.json["result"]["is_cached"]

# 3. data precision
assert "colnames" in rv_data2["result"]
assert "coltypes" in rv_data2["result"]
assert "data" in rv_data2["result"]
assert "colnames" in rv2.json["result"]
assert "coltypes" in rv2.json["result"]
assert "data" in rv2.json["result"]

eager_samples = virtual_dataset.database.get_df(
f"select * from ({virtual_dataset.sql}) as tbl"
@@ -468,7 +466,7 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
# the col3 is Decimal
eager_samples["col3"] = eager_samples["col3"].apply(float)
eager_samples = eager_samples.to_dict(orient="records")
assert eager_samples == rv_data2["result"]["data"]
assert eager_samples == rv2.json["result"]["data"]


def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
@@ -486,10 +484,9 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
rv = test_client.post(uri)
assert rv.status_code == 422

rv_data = json.loads(rv.data)
assert "error" in rv_data
assert "error" in rv.json
if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
assert "INCORRECT SQL" in rv_data.get("error")
assert "INCORRECT SQL" in rv.json.get("error")


def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
@@ -498,11 +495,10 @@ def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_d
)
rv = test_client.post(uri)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
assert QueryCacheManager.has(
rv_data["result"]["cache_key"], region=CacheRegion.DATA
rv.json["result"]["cache_key"], region=CacheRegion.DATA
)
assert len(rv_data["result"]["data"]) == 10
assert len(rv.json["result"]["data"]) == 10


def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
@@ -533,9 +529,8 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
},
)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
assert rv_data["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
assert rv_data["result"]["rowcount"] == 1
assert rv.json["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
assert rv.json["result"]["rowcount"] == 1

# empty results
rv = test_client.post(
@@ -547,9 +542,64 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
},
)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
assert rv_data["result"]["colnames"] == []
assert rv_data["result"]["rowcount"] == 0
assert rv.json["result"]["colnames"] == []
assert rv.json["result"]["rowcount"] == 0


def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset):
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
)
payload = {
"granularity": "col5",
"time_range": "2000-01-02 : 2000-01-04",
}
rv = test_client.post(uri, json=payload)
assert len(rv.json["result"]["data"]) == 2
if physical_dataset.database.backend != "sqlite":
assert [row["col5"] for row in rv.json["result"]["data"]] == [
946771200000.0, # 2000-01-02 00:00:00
946857600000.0, # 2000-01-03 00:00:00
]
assert rv.json["result"]["page"] == 1
assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
assert rv.json["result"]["total_count"] == 2


def test_get_samples_with_multiple_filters(
test_client, login_as_admin, physical_dataset
):
# 1. empty response
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
)
payload = {
"granularity": "col5",
"time_range": "2000-01-02 : 2000-01-04",
"filters": [
{"col": "col4", "op": "IS NOT NULL"},
],
}
rv = test_client.post(uri, json=payload)
assert len(rv.json["result"]["data"]) == 0

# 2. adhoc filters, time filters, and custom where
payload = {
"granularity": "col5",
"time_range": "2000-01-02 : 2000-01-04",
"filters": [
{"col": "col2", "op": "==", "val": "c"},
],
"extras": {"where": "col3 = 1.2 and col4 is null"},
}
rv = test_client.post(uri, json=payload)
assert len(rv.json["result"]["data"]) == 1
assert rv.json["result"]["total_count"] == 1
assert "2000-01-02" in rv.json["result"]["query"]
assert "2000-01-04" in rv.json["result"]["query"]
assert "col3 = 1.2" in rv.json["result"]["query"]
assert "col4 is null" in rv.json["result"]["query"]
assert "col2 = 'c'" in rv.json["result"]["query"]


def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
@@ -558,10 +608,9 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
)
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 1
assert rv_data["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
assert rv_data["result"]["total_count"] == 10
assert rv.json["result"]["page"] == 1
assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
assert rv.json["result"]["total_count"] == 10

# 2. incorrect per_page
per_pages = (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx")
@@ -582,25 +631,22 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
# 4. turning pages
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1"
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 1
assert rv_data["result"]["per_page"] == 2
assert rv_data["result"]["total_count"] == 10
assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1]
assert rv.json["result"]["page"] == 1
assert rv.json["result"]["per_page"] == 2
assert rv.json["result"]["total_count"] == 10
assert [row["col1"] for row in rv.json["result"]["data"]] == [0, 1]

uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=2"
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 2
assert rv_data["result"]["per_page"] == 2
assert rv_data["result"]["total_count"] == 10
assert [row["col1"] for row in rv_data["result"]["data"]] == [2, 3]
assert rv.json["result"]["page"] == 2
assert rv.json["result"]["per_page"] == 2
assert rv.json["result"]["total_count"] == 10
assert [row["col1"] for row in rv.json["result"]["data"]] == [2, 3]

# 5. Exceeding the maximum pages
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=6"
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 6
assert rv_data["result"]["per_page"] == 2
assert rv_data["result"]["total_count"] == 10
assert [row["col1"] for row in rv_data["result"]["data"]] == []
assert rv.json["result"]["page"] == 6
assert rv.json["result"]["per_page"] == 2
assert rv.json["result"]["total_count"] == 10
assert [row["col1"] for row in rv.json["result"]["data"]] == []
