From b9500466682c63fe447d2dfe7d539e80589016e2 Mon Sep 17 00:00:00 2001 From: Mayur Date: Mon, 10 Oct 2022 09:38:33 +0530 Subject: [PATCH] fix: allow adhoc columns in non-aggregate query (#21729) --- superset/connectors/sqla/models.py | 18 ++++++-- superset/superset_typing.py | 4 +- superset/utils/core.py | 4 +- .../charts/data/api_tests.py | 41 +++++++++++++++++++ 4 files changed, 60 insertions(+), 7 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index f06b7fa0b0c8..efd67cdd9a95 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -116,6 +116,7 @@ from superset.superset_typing import ( AdhocColumn, AdhocMetric, + Column as ColumnTyping, Metric, OrderBy, QueryObjectDict, @@ -1216,7 +1217,7 @@ def text(self, clause: str) -> TextClause: def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements self, apply_fetch_values_predicate: bool = False, - columns: Optional[List[Column]] = None, + columns: Optional[List[ColumnTyping]] = None, extras: Optional[Dict[str, Any]] = None, filter: Optional[ # pylint: disable=redefined-builtin List[QueryObjectFilterClause] @@ -1412,15 +1413,24 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma select_exprs.append(outer) elif columns: for selected in columns: + if is_adhoc_column(selected): + _sql = selected["sqlExpression"] + _column_label = selected["label"] + elif isinstance(selected, str): + _sql = selected + _column_label = selected + selected = validate_adhoc_subquery( - selected, + _sql, self.database_id, self.schema, ) select_exprs.append( columns_by_name[selected].get_sqla_col() - if selected in columns_by_name - else self.make_sqla_column_compatible(literal_column(selected)) + if isinstance(selected, str) and selected in columns_by_name + else self.make_sqla_column_compatible( + literal_column(selected), _column_label + ) ) metrics_exprs = [] diff --git a/superset/superset_typing.py b/superset/superset_typing.py index ae8787d1c691..00b76cd78cae 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -53,8 +53,8 @@ class AdhocMetric(TypedDict, total=False): class AdhocColumn(TypedDict, total=False): hasCustomLabel: Optional[bool] - label: Optional[str] - sqlExpression: Optional[str] + label: str + sqlExpression: str class ResultSetColumnType(TypedDict): diff --git a/superset/utils/core.py b/superset/utils/core.py index 5e84c0e3caa9..95bf76c5fa3d 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1253,7 +1253,9 @@ def is_adhoc_metric(metric: Metric) -> TypeGuard[AdhocMetric]: def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]: - return isinstance(column, dict) + return isinstance(column, dict) and ({"label", "sqlExpression"}).issubset( + column.keys() + ) def get_column_name( diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 73425fb58f68..6bbed00759c6 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -766,6 +766,47 @@ def test_with_virtual_table_with_colons_as_datasource(self): assert "':xyz:qwerty'" in result["query"] assert "':qwerty:'" in result["query"] + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_table_columns_without_metrics(self): + request_payload = self.query_context_payload + request_payload["queries"][0]["columns"] = ["name", "gender"] + request_payload["queries"][0]["metrics"] = None + request_payload["queries"][0]["orderby"] = [] + + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + result = rv.json["result"][0] + + assert rv.status_code == 200 + assert "name" in result["colnames"] + assert "gender" in result["colnames"] + assert "name" in result["query"] + assert "gender" in result["query"] + assert list(result["data"][0].keys()) == ["name", "gender"] + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_adhoc_column_without_metrics(self): + request_payload = self.query_context_payload + request_payload["queries"][0]["columns"] = [ + "name", + { + "label": "num divide by 10", + "sqlExpression": "num/10", + "expressionType": "SQL", + }, + ] + request_payload["queries"][0]["metrics"] = None + request_payload["queries"][0]["orderby"] = [] + + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + result = rv.json["result"][0] + + assert rv.status_code == 200 + assert "num divide by 10" in result["colnames"] + assert "name" in result["colnames"] + assert "num divide by 10" in result["query"] + assert "name" in result["query"] + assert list(result["data"][0].keys()) == ["name", "num divide by 10"] + @pytest.mark.chart_data_flow class TestGetChartDataApi(BaseTestChartDataApi):