Skip to content

Commit

Permalink
fix: allow adhoc columns in non-aggregate query (apache#21729)
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurnewase authored and Fahrenheit35 committed Nov 11, 2022
1 parent 421c95d commit b950046
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 7 deletions.
18 changes: 14 additions & 4 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
Column as ColumnTyping,
Metric,
OrderBy,
QueryObjectDict,
Expand Down Expand Up @@ -1216,7 +1217,7 @@ def text(self, clause: str) -> TextClause:
def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
self,
apply_fetch_values_predicate: bool = False,
columns: Optional[List[Column]] = None,
columns: Optional[List[ColumnTyping]] = None,
extras: Optional[Dict[str, Any]] = None,
filter: Optional[ # pylint: disable=redefined-builtin
List[QueryObjectFilterClause]
Expand Down Expand Up @@ -1412,15 +1413,24 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
select_exprs.append(outer)
elif columns:
for selected in columns:
if is_adhoc_column(selected):
_sql = selected["sqlExpression"]
_column_label = selected["label"]
elif isinstance(selected, str):
_sql = selected
_column_label = selected

selected = validate_adhoc_subquery(
selected,
_sql,
self.database_id,
self.schema,
)
select_exprs.append(
columns_by_name[selected].get_sqla_col()
if selected in columns_by_name
else self.make_sqla_column_compatible(literal_column(selected))
if isinstance(selected, str) and selected in columns_by_name
else self.make_sqla_column_compatible(
literal_column(selected), _column_label
)
)
metrics_exprs = []

Expand Down
4 changes: 2 additions & 2 deletions superset/superset_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class AdhocMetric(TypedDict, total=False):

class AdhocColumn(TypedDict, total=False):
hasCustomLabel: Optional[bool]
label: Optional[str]
sqlExpression: Optional[str]
label: str
sqlExpression: str


class ResultSetColumnType(TypedDict):
Expand Down
4 changes: 3 additions & 1 deletion superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,9 @@ def is_adhoc_metric(metric: Metric) -> TypeGuard[AdhocMetric]:


def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]:
return isinstance(column, dict)
return isinstance(column, dict) and ({"label", "sqlExpression"}).issubset(
column.keys()
)


def get_column_name(
Expand Down
41 changes: 41 additions & 0 deletions tests/integration_tests/charts/data/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,47 @@ def test_with_virtual_table_with_colons_as_datasource(self):
assert "':xyz:qwerty'" in result["query"]
assert "':qwerty:'" in result["query"]

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_table_columns_without_metrics(self):
request_payload = self.query_context_payload
request_payload["queries"][0]["columns"] = ["name", "gender"]
request_payload["queries"][0]["metrics"] = None
request_payload["queries"][0]["orderby"] = []

rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
result = rv.json["result"][0]

assert rv.status_code == 200
assert "name" in result["colnames"]
assert "gender" in result["colnames"]
assert "name" in result["query"]
assert "gender" in result["query"]
assert list(result["data"][0].keys()) == ["name", "gender"]

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_adhoc_column_without_metrics(self):
request_payload = self.query_context_payload
request_payload["queries"][0]["columns"] = [
"name",
{
"label": "num divide by 10",
"sqlExpression": "num/10",
"expressionType": "SQL",
},
]
request_payload["queries"][0]["metrics"] = None
request_payload["queries"][0]["orderby"] = []

rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
result = rv.json["result"][0]

assert rv.status_code == 200
assert "num divide by 10" in result["colnames"]
assert "name" in result["colnames"]
assert "num divide by 10" in result["query"]
assert "name" in result["query"]
assert list(result["data"][0].keys()) == ["name", "num divide by 10"]


@pytest.mark.chart_data_flow
class TestGetChartDataApi(BaseTestChartDataApi):
Expand Down

0 comments on commit b950046

Please sign in to comment.