diff --git a/superset/common/chart_data.py b/superset/common/chart_data.py
index f3917d6d8717..ea31d4f13817 100644
--- a/superset/common/chart_data.py
+++ b/superset/common/chart_data.py
@@ -38,3 +38,4 @@ class ChartDataResultType(str, Enum):
     SAMPLES = "samples"
     TIMEGRAINS = "timegrains"
     POST_PROCESSED = "post_processed"
+    DRILL_DETAIL = "drill_detail"
diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py
index 0764e19340c9..bfb3d368789d 100644
--- a/superset/common/query_actions.py
+++ b/superset/common/query_actions.py
@@ -162,6 +162,27 @@ def _get_samples(
     return _get_full(query_context, query_obj, force_cached)
 
 
+def _get_drill_detail(
+    query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
+) -> Dict[str, Any]:
+    # todo(yongjie): remove this function once it has been decided
+    # whether the time filter should be applied to samples
+    datasource = _get_datasource(query_context, query_obj)
+    query_obj = copy.copy(query_obj)
+    query_obj.is_timeseries = False
+    query_obj.orderby = []
+    query_obj.metrics = None
+    query_obj.post_processing = []
+    qry_obj_cols = []
+    for o in datasource.columns:
+        if isinstance(o, dict):
+            qry_obj_cols.append(o.get("column_name"))
+        else:
+            qry_obj_cols.append(o.column_name)
+    query_obj.columns = qry_obj_cols
+    return _get_full(query_context, query_obj, force_cached)
+
+
 def _get_results(
     query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
 ) -> Dict[str, Any]:
@@ -182,6 +203,7 @@ def _get_results(
     # and post-process it later where we have the chart context, since
     # post-processing is unique to each visualization type
     ChartDataResultType.POST_PROCESSED: _get_full,
+    ChartDataResultType.DRILL_DETAIL: _get_drill_detail,
 }
 
 
diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py
index 4c97f17e88fe..f9be7a7d4e1f 100644
--- a/superset/views/datasource/schemas.py
+++ b/superset/views/datasource/schemas.py
@@ -20,7 +20,7 @@
 from typing_extensions import TypedDict
 
 from superset import app
-from superset.charts.schemas import ChartDataFilterSchema
+from superset.charts.schemas import ChartDataExtrasSchema, ChartDataFilterSchema
 from superset.utils.core import DatasourceType
 
 
@@ -62,6 +62,17 @@ def normalize(
 
 class SamplesPayloadSchema(Schema):
     filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
+    granularity = fields.String(
+        allow_none=True,
+    )
+    time_range = fields.String(
+        allow_none=True,
+    )
+    extras = fields.Nested(
+        ChartDataExtrasSchema,
+        description="Extra parameters to add to the query.",
+        allow_none=True,
+    )
 
     @pre_load
     # pylint: disable=no-self-use, unused-argument
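
Note: with SamplesPayloadSchema extended as above, a samples request body can now carry the same time-filter fields as a chart query. A minimal sketch of such a payload (column names and values are taken from the tests further down and are purely illustrative):

payload = {
    "granularity": "col5",  # temporal column to filter on
    "time_range": "2000-01-02 : 2000-01-04",  # end appears to be exclusive, per the assertions below
    "filters": [{"col": "col2", "op": "==", "val": "c"}],
    "extras": {"where": "col3 = 1.2 and col4 is null"},
}
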
diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py
index 0191db2947c2..42cddf416794 100644
--- a/superset/views/datasource/utils.py
+++ b/superset/views/datasource/utils.py
@@ -60,17 +60,30 @@ def get_samples(  # pylint: disable=too-many-arguments,too-many-locals
     limit_clause = get_limit_clause(page, per_page)
 
     # todo(yongjie): Constructing count(*) and samples in the same query_context,
-    # then remove query_type==SAMPLES
-    # constructing samples query
-    samples_instance = QueryContextFactory().create(
-        datasource={
-            "type": datasource.type,
-            "id": datasource.id,
-        },
-        queries=[{**payload, **limit_clause} if payload else limit_clause],
-        result_type=ChartDataResultType.SAMPLES,
-        force=force,
-    )
+    if payload is None:
+        # constructing samples query
+        samples_instance = QueryContextFactory().create(
+            datasource={
+                "type": datasource.type,
+                "id": datasource.id,
+            },
+            queries=[limit_clause],
+            result_type=ChartDataResultType.SAMPLES,
+            force=force,
+        )
+    else:
+        # constructing drill-detail query
+        # when the result type is SAMPLES the time filter is removed from the
+        # query, so SAMPLES is not applicable to the drill-detail query
+        samples_instance = QueryContextFactory().create(
+            datasource={
+                "type": datasource.type,
+                "id": datasource.id,
+            },
+            queries=[{**payload, **limit_clause}],
+            result_type=ChartDataResultType.DRILL_DETAIL,
+            force=force,
+        )
 
     # constructing count(*) query
     count_star_metric = {
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index 043d7922193a..549a987db135 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -314,7 +314,7 @@ def physical_dataset():
         col2 VARCHAR(255),
         col3 DECIMAL(4,2),
         col4 VARCHAR(255),
-        col5 VARCHAR(255)
+        col5 TIMESTAMP
         );
         """
     )
@@ -342,11 +342,10 @@ def physical_dataset():
     TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
     TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
     TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
-    TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
+    TableColumn(column_name="col5", type="TIMESTAMP", is_dttm=True, table=dataset)
     SqlMetric(metric_name="count", expression="count(*)", table=dataset)
     db.session.merge(dataset)
-    if example_database.backend == "sqlite":
-        db.session.commit()
+    db.session.commit()
 
     yield dataset
 
@@ -355,5 +354,7 @@ def physical_dataset():
         DROP TABLE physical_dataset;
         """
     )
-    db.session.delete(dataset)
+    dataset = db.session.query(SqlaTable).filter_by(table_name="physical_dataset").all()
+    for ds in dataset:
+        db.session.delete(ds)
     db.session.commit()
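
Two notes on the test changes that follow: `rv.json` is the Werkzeug test-response property equivalent to `json.loads(rv.data)`, which is why the explicit deserialization disappears below; and the assertions rely on a samples/drill-detail response shaped roughly like this sketch (keys are the ones asserted in the tests, every value is illustrative):

sample_response = {
    "result": {
        "cache_key": "<cache key>",
        "is_cached": False,
        "query": "SELECT ...",  # SQL actually issued; asserted on in the drill-detail tests
        "colnames": ["col1", "col2", "col3", "col4", "col5"],
        "coltypes": [],  # present, contents not asserted here
        "data": [{"col1": 0, "col2": "a"}],  # row dicts
        "rowcount": 10,
        "page": 1,
        "per_page": 1000,  # app.config["SAMPLES_ROW_LIMIT"] unless per_page is passed
        "total_count": 10,
    }
}
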
diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py
index ad4d625cc5ae..ef3ba0c69d6b 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -432,14 +432,13 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
     test_client.post(uri)
     # get from cache
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
     assert rv.status_code == 200
-    assert len(rv_data["result"]["data"]) == 10
+    assert len(rv.json["result"]["data"]) == 10
     assert QueryCacheManager.has(
-        rv_data["result"]["cache_key"],
+        rv.json["result"]["cache_key"],
         region=CacheRegion.DATA,
     )
-    assert rv_data["result"]["is_cached"]
+    assert rv.json["result"]["is_cached"]
 
     # 2. should read through cache data
     uri2 = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true"
@@ -447,19 +446,18 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
     test_client.post(uri2)
     # force query
     rv2 = test_client.post(uri2)
-    rv_data2 = json.loads(rv2.data)
     assert rv2.status_code == 200
-    assert len(rv_data2["result"]["data"]) == 10
+    assert len(rv2.json["result"]["data"]) == 10
     assert QueryCacheManager.has(
-        rv_data2["result"]["cache_key"],
+        rv2.json["result"]["cache_key"],
         region=CacheRegion.DATA,
     )
-    assert not rv_data2["result"]["is_cached"]
+    assert not rv2.json["result"]["is_cached"]
 
     # 3. data precision
-    assert "colnames" in rv_data2["result"]
-    assert "coltypes" in rv_data2["result"]
-    assert "data" in rv_data2["result"]
+    assert "colnames" in rv2.json["result"]
+    assert "coltypes" in rv2.json["result"]
+    assert "data" in rv2.json["result"]
 
     eager_samples = virtual_dataset.database.get_df(
         f"select * from ({virtual_dataset.sql}) as tbl"
@@ -468,7 +466,7 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
     # the col3 is Decimal
     eager_samples["col3"] = eager_samples["col3"].apply(float)
     eager_samples = eager_samples.to_dict(orient="records")
-    assert eager_samples == rv_data2["result"]["data"]
+    assert eager_samples == rv2.json["result"]["data"]
 
 
 def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
@@ -486,10 +484,9 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
     rv = test_client.post(uri)
     assert rv.status_code == 422
 
-    rv_data = json.loads(rv.data)
-    assert "error" in rv_data
+    assert "error" in rv.json
     if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
-        assert "INCORRECT SQL" in rv_data.get("error")
+        assert "INCORRECT SQL" in rv.json.get("error")
 
 
 def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
@@ -498,11 +495,10 @@ def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_d
     )
     rv = test_client.post(uri)
     assert rv.status_code == 200
-    rv_data = json.loads(rv.data)
     assert QueryCacheManager.has(
-        rv_data["result"]["cache_key"], region=CacheRegion.DATA
+        rv.json["result"]["cache_key"], region=CacheRegion.DATA
     )
-    assert len(rv_data["result"]["data"]) == 10
+    assert len(rv.json["result"]["data"]) == 10
 
 
 def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
@@ -533,9 +529,8 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
         },
     )
     assert rv.status_code == 200
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
-    assert rv_data["result"]["rowcount"] == 1
+    assert rv.json["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
+    assert rv.json["result"]["rowcount"] == 1
 
     # empty results
     rv = test_client.post(
@@ -547,9 +542,64 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
         },
     )
     assert rv.status_code == 200
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["colnames"] == []
-    assert rv_data["result"]["rowcount"] == 0
+    assert rv.json["result"]["colnames"] == []
+    assert rv.json["result"]["rowcount"] == 0
+
+
+def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset):
+    uri = (
+        f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
+    )
+    payload = {
+        "granularity": "col5",
+        "time_range": "2000-01-02 : 2000-01-04",
+    }
+    rv = test_client.post(uri, json=payload)
+    assert len(rv.json["result"]["data"]) == 2
+    if physical_dataset.database.backend != "sqlite":
+        assert [row["col5"] for row in rv.json["result"]["data"]] == [
+            946771200000.0,  # 2000-01-02 00:00:00
+            946857600000.0,  # 2000-01-03 00:00:00
+        ]
+    assert rv.json["result"]["page"] == 1
+    assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
+    assert rv.json["result"]["total_count"] == 2
+
+
+def test_get_samples_with_multiple_filters(
+    test_client, login_as_admin, physical_dataset
+):
+    # 1. empty response
+    uri = (
+        f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
+    )
+    payload = {
+        "granularity": "col5",
+        "time_range": "2000-01-02 : 2000-01-04",
+        "filters": [
+            {"col": "col4", "op": "IS NOT NULL"},
+        ],
+    }
+    rv = test_client.post(uri, json=payload)
+    assert len(rv.json["result"]["data"]) == 0
+
+    # 2. adhoc filters, time filters, and custom where
+    payload = {
+        "granularity": "col5",
+        "time_range": "2000-01-02 : 2000-01-04",
+        "filters": [
+            {"col": "col2", "op": "==", "val": "c"},
+        ],
+        "extras": {"where": "col3 = 1.2 and col4 is null"},
+    }
+    rv = test_client.post(uri, json=payload)
+    assert len(rv.json["result"]["data"]) == 1
+    assert rv.json["result"]["total_count"] == 1
+    assert "2000-01-02" in rv.json["result"]["query"]
+    assert "2000-01-04" in rv.json["result"]["query"]
+    assert "col3 = 1.2" in rv.json["result"]["query"]
+    assert "col4 is null" in rv.json["result"]["query"]
+    assert "col2 = 'c'" in rv.json["result"]["query"]
 
 
 def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
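
The two float constants asserted in test_get_samples_with_time_filter above are epoch milliseconds for midnight UTC on 2000-01-02 and 2000-01-03; a quick standalone check (not part of the change):

from datetime import datetime, timezone

# start of 2000-01-02 and 2000-01-03 in UTC, expressed as epoch milliseconds
assert datetime(2000, 1, 2, tzinfo=timezone.utc).timestamp() * 1000 == 946771200000.0
assert datetime(2000, 1, 3, tzinfo=timezone.utc).timestamp() * 1000 == 946857600000.0
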
@@ -558,10 +608,9 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
         f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
     )
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["page"] == 1
-    assert rv_data["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
-    assert rv_data["result"]["total_count"] == 10
+    assert rv.json["result"]["page"] == 1
+    assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
+    assert rv.json["result"]["total_count"] == 10
 
     # 2. incorrect per_page
     per_pages = (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx")
@@ -582,25 +631,22 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
     # 4. turning pages
     uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1"
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["page"] == 1
-    assert rv_data["result"]["per_page"] == 2
-    assert rv_data["result"]["total_count"] == 10
-    assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1]
+    assert rv.json["result"]["page"] == 1
+    assert rv.json["result"]["per_page"] == 2
+    assert rv.json["result"]["total_count"] == 10
+    assert [row["col1"] for row in rv.json["result"]["data"]] == [0, 1]
 
     uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=2"
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["page"] == 2
-    assert rv_data["result"]["per_page"] == 2
-    assert rv_data["result"]["total_count"] == 10
-    assert [row["col1"] for row in rv_data["result"]["data"]] == [2, 3]
+    assert rv.json["result"]["page"] == 2
+    assert rv.json["result"]["per_page"] == 2
+    assert rv.json["result"]["total_count"] == 10
+    assert [row["col1"] for row in rv.json["result"]["data"]] == [2, 3]
 
     # 5. Exceeding the maximum pages
     uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=6"
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["page"] == 6
-    assert rv_data["result"]["per_page"] == 2
-    assert rv_data["result"]["total_count"] == 10
-    assert [row["col1"] for row in rv_data["result"]["data"]] == []
+    assert rv.json["result"]["page"] == 6
+    assert rv.json["result"]["per_page"] == 2
+    assert rv.json["result"]["total_count"] == 10
+    assert [row["col1"] for row in rv.json["result"]["data"]] == []
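
Outside the test client, the same drill-detail request could be issued against a running Superset roughly as sketched below; the host, datasource id, and the pre-authenticated session are assumptions for illustration, not part of this change:

import requests

session = requests.Session()
# assume the session has already been authenticated against the Superset instance
resp = session.post(
    "http://localhost:8088/datasource/samples",
    params={"datasource_id": 1, "datasource_type": "table", "per_page": 50, "page": 1},
    json={
        "granularity": "col5",
        "time_range": "2000-01-02 : 2000-01-04",
        "filters": [{"col": "col2", "op": "==", "val": "c"}],
        "extras": {"where": "col3 = 1.2 and col4 is null"},
    },
)
print(resp.json()["result"]["total_count"])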