feat: supports multiple filters in samples endpoint #21008

Merged 4 commits on Aug 8, 2022
1 change: 1 addition & 0 deletions superset/common/chart_data.py
@@ -38,3 +38,4 @@ class ChartDataResultType(str, Enum):
SAMPLES = "samples"
TIMEGRAINS = "timegrains"
POST_PROCESSED = "post_processed"
DRILL_DETAIL = "drill_detail"
22 changes: 22 additions & 0 deletions superset/common/query_actions.py
@@ -162,6 +162,27 @@ def _get_samples(
return _get_full(query_context, query_obj, force_cached)


def _get_drill_detail(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
) -> Dict[str, Any]:
# TODO(yongjie): remove this function once it is decided whether
# the time filter should be applied to samples.
datasource = _get_datasource(query_context, query_obj)
query_obj = copy.copy(query_obj)
query_obj.is_timeseries = False
query_obj.orderby = []
query_obj.metrics = None
query_obj.post_processing = []
qry_obj_cols = []
for o in datasource.columns:
if isinstance(o, dict):
qry_obj_cols.append(o.get("column_name"))
else:
qry_obj_cols.append(o.column_name)
query_obj.columns = qry_obj_cols
return _get_full(query_context, query_obj, force_cached)


def _get_results(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
) -> Dict[str, Any]:
@@ -182,6 +203,7 @@ def _get_results(
# and post-process it later where we have the chart context, since
# post-processing is unique to each visualization type
ChartDataResultType.POST_PROCESSED: _get_full,
ChartDataResultType.DRILL_DETAIL: _get_drill_detail,
}


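As an aside, the column-collection loop in _get_drill_detail above handles both dict-shaped column metadata and TableColumn objects. An equivalent, self-contained sketch of that normalization, using stand-in data rather than a real datasource, and not a proposed change to the PR:

from types import SimpleNamespace

# Stand-ins for the two shapes datasource.columns can contain.
columns = [{"column_name": "col1"}, SimpleNamespace(column_name="col2")]

qry_obj_cols = [
    col.get("column_name") if isinstance(col, dict) else col.column_name
    for col in columns
]
assert qry_obj_cols == ["col1", "col2"]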
13 changes: 12 additions & 1 deletion superset/views/datasource/schemas.py
@@ -20,7 +20,7 @@
from typing_extensions import TypedDict

from superset import app
from superset.charts.schemas import ChartDataFilterSchema
from superset.charts.schemas import ChartDataExtrasSchema, ChartDataFilterSchema
from superset.utils.core import DatasourceType


@@ -62,6 +62,17 @@ def normalize(

class SamplesPayloadSchema(Schema):
filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
granularity = fields.String(
allow_none=True,
)
time_range = fields.String(
allow_none=True,
)
extras = fields.Nested(
ChartDataExtrasSchema,
description="Extra parameters to add to the query.",
allow_none=True,
)

@pre_load
# pylint: disable=no-self-use, unused-argument
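With these additions, SamplesPayloadSchema accepts adhoc filters, a temporal column plus time range, and free-form SQL under extras.where. A sketch of exercising the schema directly; the payload values are illustrative and mirror the integration tests below:

from superset.views.datasource.schemas import SamplesPayloadSchema

payload = SamplesPayloadSchema().load(
    {
        "filters": [{"col": "col2", "op": "==", "val": "c"}],
        "granularity": "col5",
        "time_range": "2000-01-02 : 2000-01-04",
        "extras": {"where": "col3 = 1.2 and col4 is null"},
    }
)
# `payload` is the validated dict that the samples view passes on to get_samples().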
35 changes: 24 additions & 11 deletions superset/views/datasource/utils.py
@@ -60,17 +60,30 @@ def get_samples( # pylint: disable=too-many-arguments,too-many-locals
limit_clause = get_limit_clause(page, per_page)

# TODO(yongjie): construct the count(*) and samples queries in the same
# query_context, then remove query_type == SAMPLES
# constructing samples query
samples_instance = QueryContextFactory().create(
datasource={
"type": datasource.type,
"id": datasource.id,
},
queries=[{**payload, **limit_clause} if payload else limit_clause],
result_type=ChartDataResultType.SAMPLES,
force=force,
)
if payload is None:
# constructing samples query
samples_instance = QueryContextFactory().create(
datasource={
"type": datasource.type,
"id": datasource.id,
},
queries=[limit_clause],
result_type=ChartDataResultType.SAMPLES,
force=force,
)
else:
# constructing drill detail query
# When query_type == 'samples' the time filter is removed,
# so SAMPLES is not suitable for a drill-detail query
samples_instance = QueryContextFactory().create(
datasource={
"type": datasource.type,
"id": datasource.id,
},
queries=[{**payload, **limit_clause}],
result_type=ChartDataResultType.DRILL_DETAIL,
force=force,
)

# constructing count(*) query
count_star_metric = {
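In short: a request with no payload keeps the old SAMPLES path, while any payload is routed through DRILL_DETAIL so the time filter is applied. A hedged sketch of calling the endpoint from outside the test client; the base URL, datasource id, and authentication are assumptions, while the URL and body shapes match the integration tests below:

import requests  # assumes a locally running Superset and an authenticated session

BASE = "http://localhost:8088"  # assumed local instance
session = requests.Session()
# ... log in and attach CSRF/auth headers here (omitted) ...

resp = session.post(
    f"{BASE}/datasource/samples",
    params={"datasource_id": 1, "datasource_type": "table", "per_page": 50, "page": 1},
    json={
        "granularity": "col5",
        "time_range": "2000-01-02 : 2000-01-04",
        "filters": [{"col": "col2", "op": "==", "val": "c"}],
        "extras": {"where": "col3 = 1.2 and col4 is null"},
    },
)
rows = resp.json()["result"]["data"]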
11 changes: 6 additions & 5 deletions tests/integration_tests/conftest.py
@@ -314,7 +314,7 @@ def physical_dataset():
col2 VARCHAR(255),
col3 DECIMAL(4,2),
col4 VARCHAR(255),
col5 VARCHAR(255)
col5 TIMESTAMP
);
"""
)
@@ -342,11 +342,10 @@ def physical_dataset():
TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col5", type="TIMESTAMP", is_dttm=True, table=dataset)
SqlMetric(metric_name="count", expression="count(*)", table=dataset)
db.session.merge(dataset)
if example_database.backend == "sqlite":
db.session.commit()
db.session.commit()

yield dataset

@@ -355,5 +354,7 @@ def physical_dataset():
DROP TABLE physical_dataset;
"""
)
db.session.delete(dataset)
dataset = db.session.query(SqlaTable).filter_by(table_name="physical_dataset").all()
for ds in dataset:
db.session.delete(ds)
db.session.commit()
132 changes: 89 additions & 43 deletions tests/integration_tests/datasource_tests.py
@@ -432,34 +432,32 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
test_client.post(uri)
# get from cache
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv.status_code == 200
assert len(rv_data["result"]["data"]) == 10
assert len(rv.json["result"]["data"]) == 10
assert QueryCacheManager.has(
rv_data["result"]["cache_key"],
rv.json["result"]["cache_key"],
region=CacheRegion.DATA,
)
assert rv_data["result"]["is_cached"]
assert rv.json["result"]["is_cached"]

# 2. should read through cache data
uri2 = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true"
# feeds data
test_client.post(uri2)
# force query
rv2 = test_client.post(uri2)
rv_data2 = json.loads(rv2.data)
assert rv2.status_code == 200
assert len(rv_data2["result"]["data"]) == 10
assert len(rv2.json["result"]["data"]) == 10
assert QueryCacheManager.has(
rv_data2["result"]["cache_key"],
rv2.json["result"]["cache_key"],
region=CacheRegion.DATA,
)
assert not rv_data2["result"]["is_cached"]
assert not rv2.json["result"]["is_cached"]

# 3. data precision
assert "colnames" in rv_data2["result"]
assert "coltypes" in rv_data2["result"]
assert "data" in rv_data2["result"]
assert "colnames" in rv2.json["result"]
assert "coltypes" in rv2.json["result"]
assert "data" in rv2.json["result"]

eager_samples = virtual_dataset.database.get_df(
f"select * from ({virtual_dataset.sql}) as tbl"
@@ -468,7 +466,7 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
# the col3 is Decimal
eager_samples["col3"] = eager_samples["col3"].apply(float)
eager_samples = eager_samples.to_dict(orient="records")
assert eager_samples == rv_data2["result"]["data"]
assert eager_samples == rv2.json["result"]["data"]


def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
@@ -486,10 +484,9 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
rv = test_client.post(uri)
assert rv.status_code == 422

rv_data = json.loads(rv.data)
assert "error" in rv_data
assert "error" in rv.json
if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
assert "INCORRECT SQL" in rv_data.get("error")
assert "INCORRECT SQL" in rv.json.get("error")


def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
@@ -498,11 +495,10 @@ def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_d
)
rv = test_client.post(uri)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
assert QueryCacheManager.has(
rv_data["result"]["cache_key"], region=CacheRegion.DATA
rv.json["result"]["cache_key"], region=CacheRegion.DATA
)
assert len(rv_data["result"]["data"]) == 10
assert len(rv.json["result"]["data"]) == 10


def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
@@ -533,9 +529,8 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
},
)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
assert rv_data["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
assert rv_data["result"]["rowcount"] == 1
assert rv.json["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
assert rv.json["result"]["rowcount"] == 1

# empty results
rv = test_client.post(
Expand All @@ -547,9 +542,64 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
},
)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
assert rv_data["result"]["colnames"] == []
assert rv_data["result"]["rowcount"] == 0
assert rv.json["result"]["colnames"] == []
assert rv.json["result"]["rowcount"] == 0


def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset):
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
)
payload = {
"granularity": "col5",
"time_range": "2000-01-02 : 2000-01-04",
}
rv = test_client.post(uri, json=payload)
assert len(rv.json["result"]["data"]) == 2
if physical_dataset.database.backend != "sqlite":
assert [row["col5"] for row in rv.json["result"]["data"]] == [
946771200000.0, # 2000-01-02 00:00:00
946857600000.0, # 2000-01-03 00:00:00
]
assert rv.json["result"]["page"] == 1
assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
assert rv.json["result"]["total_count"] == 2


def test_get_samples_with_multiple_filters(
test_client, login_as_admin, physical_dataset
):
# 1. empty response
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
)
payload = {
"granularity": "col5",
"time_range": "2000-01-02 : 2000-01-04",
"filters": [
{"col": "col4", "op": "IS NOT NULL"},
],
}
rv = test_client.post(uri, json=payload)
assert len(rv.json["result"]["data"]) == 0

# 2. adhoc filters, time filters, and custom where
payload = {
"granularity": "col5",
"time_range": "2000-01-02 : 2000-01-04",
"filters": [
{"col": "col2", "op": "==", "val": "c"},
],
"extras": {"where": "col3 = 1.2 and col4 is null"},
}
rv = test_client.post(uri, json=payload)
assert len(rv.json["result"]["data"]) == 1
assert rv.json["result"]["total_count"] == 1
assert "2000-01-02" in rv.json["result"]["query"]
assert "2000-01-04" in rv.json["result"]["query"]
assert "col3 = 1.2" in rv.json["result"]["query"]
assert "col4 is null" in rv.json["result"]["query"]
assert "col2 = 'c'" in rv.json["result"]["query"]


def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
@@ -558,10 +608,9 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
)
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 1
assert rv_data["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
assert rv_data["result"]["total_count"] == 10
assert rv.json["result"]["page"] == 1
assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
assert rv.json["result"]["total_count"] == 10

# 2. incorrect per_page
per_pages = (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx")
@@ -582,25 +631,22 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
# 4. turning pages
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1"
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 1
assert rv_data["result"]["per_page"] == 2
assert rv_data["result"]["total_count"] == 10
assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1]
assert rv.json["result"]["page"] == 1
assert rv.json["result"]["per_page"] == 2
assert rv.json["result"]["total_count"] == 10
assert [row["col1"] for row in rv.json["result"]["data"]] == [0, 1]

uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=2"
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 2
assert rv_data["result"]["per_page"] == 2
assert rv_data["result"]["total_count"] == 10
assert [row["col1"] for row in rv_data["result"]["data"]] == [2, 3]
assert rv.json["result"]["page"] == 2
assert rv.json["result"]["per_page"] == 2
assert rv.json["result"]["total_count"] == 10
assert [row["col1"] for row in rv.json["result"]["data"]] == [2, 3]

# 5. Exceeding the maximum pages
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=6"
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 6
assert rv_data["result"]["per_page"] == 2
assert rv_data["result"]["total_count"] == 10
assert [row["col1"] for row in rv_data["result"]["data"]] == []
assert rv.json["result"]["page"] == 6
assert rv.json["result"]["per_page"] == 2
assert rv.json["result"]["total_count"] == 10
assert [row["col1"] for row in rv.json["result"]["data"]] == []