Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(csv-export): pivot v2 with verbose names #18633

Merged
merged 3 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion superset/charts/data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ def data(self) -> Response:
return self._run_async(json_body, command)

form_data = json_body.get("form_data")
return self._get_data_response(command, form_data=form_data)
return self._get_data_response(
command, form_data=form_data, datasource=query_context.datasource
)

@expose("/data/<cache_key>", methods=["GET"])
@protect()
Expand Down
41 changes: 30 additions & 11 deletions superset/charts/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
import pandas as pd

from superset.common.chart_data import ChartDataResultFormat
from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name
from superset.utils.core import (
DTTM_ALIAS,
extract_dataframe_dtypes,
get_column_names,
get_metric_names,
)

if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
Expand Down Expand Up @@ -214,18 +219,23 @@ def list_unique_values(series: pd.Series) -> str:
}


def pivot_table_v2(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:
def pivot_table_v2(
df: pd.DataFrame,
form_data: Dict[str, Any],
datasource: Optional["BaseDatasource"] = None,
) -> pd.DataFrame:
"""
Pivot table v2.
"""
verbose_map = datasource.data["verbose_map"] if datasource else None
if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df:
del df[DTTM_ALIAS]

return pivot_df(
df,
rows=form_data.get("groupbyRows") or [],
columns=form_data.get("groupbyColumns") or [],
metrics=[get_metric_name(m) for m in form_data["metrics"]],
rows=get_column_names(form_data.get("groupbyRows"), verbose_map),
columns=get_column_names(form_data.get("groupbyColumns"), verbose_map),
metrics=get_metric_names(form_data["metrics"], verbose_map),
aggfunc=form_data.get("aggregateFunction", "Sum"),
transpose_pivot=bool(form_data.get("transposePivot")),
combine_metrics=bool(form_data.get("combineMetric")),
Expand All @@ -235,10 +245,15 @@ def pivot_table_v2(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:
)


def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:
def pivot_table(
df: pd.DataFrame,
form_data: Dict[str, Any],
datasource: Optional["BaseDatasource"] = None,
) -> pd.DataFrame:
"""
Pivot table (v1).
"""
verbose_map = datasource.data["verbose_map"] if datasource else None
if form_data.get("granularity") == "all" and DTTM_ALIAS in df:
del df[DTTM_ALIAS]

Expand All @@ -254,9 +269,9 @@ def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:

return pivot_df(
df,
rows=form_data.get("groupby") or [],
columns=form_data.get("columns") or [],
metrics=[get_metric_name(m) for m in form_data["metrics"]],
rows=get_column_names(form_data.get("groupby"), verbose_map),
columns=get_column_names(form_data.get("columns"), verbose_map),
metrics=get_metric_names(form_data["metrics"], verbose_map),
aggfunc=func_map.get(form_data.get("pandas_aggfunc", "sum"), "Sum"),
transpose_pivot=bool(form_data.get("transpose_pivot")),
combine_metrics=bool(form_data.get("combine_metric")),
Expand All @@ -266,7 +281,11 @@ def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:
)


def table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:
def table(
df: pd.DataFrame,
form_data: Dict[str, Any],
datasource: Optional["BaseDatasource"] = None, # pylint: disable=unused-argument
) -> pd.DataFrame:
"""
Table.
"""
Expand Down Expand Up @@ -312,7 +331,7 @@ def apply_post_process(
else:
raise Exception(f"Result format {query['result_format']} not supported")

processed_df = post_processor(df, form_data)
processed_df = post_processor(df, form_data, datasource)

query["colnames"] = list(processed_df.columns)
query["indexnames"] = list(processed_df.index)
Expand Down
47 changes: 36 additions & 11 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,11 +1228,15 @@ def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]:
return isinstance(column, dict)


def get_column_name(column: Column) -> str:
def get_column_name(
column: Column, verbose_map: Optional[Dict[str, Any]] = None
) -> str:
"""
Extract label from column

:param column: object to extract label from
:param verbose_map: verbose_map from dataset for optional mapping from
raw name to verbose name
:return: String representation of column
:raises ValueError: if metric object is invalid
"""
Expand All @@ -1243,15 +1247,20 @@ def get_column_name(column: Column) -> str:
expr = column.get("sqlExpression")
if expr:
return expr
raise Exception("Missing label")
return column
raise ValueError("Missing label")
verbose_map = verbose_map or {}
return verbose_map.get(column, column)


def get_metric_name(metric: Metric) -> str:
def get_metric_name(
metric: Metric, verbose_map: Optional[Dict[str, Any]] = None
) -> str:
"""
Extract label from metric

:param metric: object to extract label from
:param verbose_map: verbose_map from dataset for optional mapping from
raw name to verbose name
:return: String representation of metric
:raises ValueError: if metric object is invalid
"""
Expand All @@ -1273,19 +1282,35 @@ def get_metric_name(metric: Metric) -> str:
if column_name:
return column_name
raise ValueError(__("Invalid metric object"))
return metric # type: ignore

verbose_map = verbose_map or {}
return verbose_map.get(metric, metric) # type: ignore


def get_column_names(columns: Optional[Sequence[Column]]) -> List[str]:
return [column for column in map(get_column_name, columns or []) if column]
def get_column_names(
    columns: Optional[Sequence[Column]], verbose_map: Optional[Dict[str, Any]] = None,
) -> List[str]:
    """
    Extract labels from a sequence of columns

    :param columns: objects to extract labels from; a falsy value (e.g. ``None``)
                    is treated as an empty sequence
    :param verbose_map: verbose_map from dataset, forwarded to
                        ``get_column_name`` for each column
    :return: String representations of the columns, with empty labels dropped
    """
    labels = (get_column_name(item, verbose_map) for item in columns or [])
    return [label for label in labels if label]


def get_metric_names(metrics: Optional[Sequence[Metric]]) -> List[str]:
return [metric for metric in map(get_metric_name, metrics or []) if metric]
def get_metric_names(
    metrics: Optional[Sequence[Metric]], verbose_map: Optional[Dict[str, Any]] = None,
) -> List[str]:
    """
    Extract labels from a sequence of metrics

    :param metrics: objects to extract labels from; a falsy value (e.g. ``None``)
                    is treated as an empty sequence
    :param verbose_map: verbose_map from dataset, forwarded to
                        ``get_metric_name`` for each metric
    :return: String representations of the metrics, with empty labels dropped
    """
    labels = (get_metric_name(item, verbose_map) for item in metrics or [])
    return [label for label in labels if label]


def get_first_metric_name(metrics: Optional[Sequence[Metric]]) -> Optional[str]:
metric_labels = get_metric_names(metrics)
def get_first_metric_name(
    metrics: Optional[Sequence[Metric]], verbose_map: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
    """
    Extract the label of the first metric that has a non-empty label

    :param metrics: objects to extract labels from
    :param verbose_map: verbose_map from dataset, forwarded to
                        ``get_metric_names``
    :return: The first label returned by ``get_metric_names``, or ``None``
             when it returns no labels
    """
    names = get_metric_names(metrics, verbose_map)
    if not names:
        return None
    return names[0]


Expand Down
53 changes: 53 additions & 0 deletions tests/unit_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
import pytest

from superset.utils.core import (
AdhocColumn,
AdhocMetric,
ExtraFiltersReasonType,
ExtraFiltersTimeColumnType,
GenericDataType,
get_column_name,
get_column_names,
get_metric_name,
get_metric_names,
get_time_filter_status,
Expand All @@ -47,15 +50,23 @@
"label": "my_sql",
"sqlExpression": "SUM(my_col)",
}
STR_COLUMN = "my_column"
SQL_ADHOC_COLUMN: AdhocColumn = {
"hasCustomLabel": True,
"label": "My Adhoc Column",
"sqlExpression": "case when foo = 1 then 'foo' else 'bar' end",
}


def test_get_metric_name_saved_metric():
    """A saved (string) metric yields its raw name, or the mapped verbose name."""
    verbose_names = {STR_METRIC: "My Metric"}
    assert get_metric_name(STR_METRIC) == "my_metric"
    assert get_metric_name(STR_METRIC, verbose_names) == "My Metric"


def test_get_metric_name_adhoc():
metric = deepcopy(SIMPLE_SUM_ADHOC_METRIC)
assert get_metric_name(metric) == "my SUM"
assert get_metric_name(metric, {"my SUM": "My Irrelevant Mapping"}) == "my SUM"
del metric["label"]
assert get_metric_name(metric) == "SUM(my_col)"
metric["label"] = ""
Expand All @@ -64,9 +75,11 @@ def test_get_metric_name_adhoc():
assert get_metric_name(metric) == "my_col"
metric["aggregate"] = ""
assert get_metric_name(metric) == "my_col"
assert get_metric_name(metric, {"my_col": "My Irrelevant Mapping"}) == "my_col"

metric = deepcopy(SQL_ADHOC_METRIC)
assert get_metric_name(metric) == "my_sql"
assert get_metric_name(metric, {"my_sql": "My Irrelevant Mapping"}) == "my_sql"
del metric["label"]
assert get_metric_name(metric) == "SUM(my_col)"
metric["label"] = ""
Expand Down Expand Up @@ -97,6 +110,46 @@ def test_get_metric_names():
assert get_metric_names(
[STR_METRIC, SIMPLE_SUM_ADHOC_METRIC, SQL_ADHOC_METRIC]
) == ["my_metric", "my SUM", "my_sql"]
assert get_metric_names(
[STR_METRIC, SIMPLE_SUM_ADHOC_METRIC, SQL_ADHOC_METRIC],
{STR_METRIC: "My Metric"},
) == ["My Metric", "my SUM", "my_sql"]


def test_get_column_name_physical_column():
    """A physical (string) column yields its raw name, or the mapped verbose name."""
    assert get_column_name(STR_COLUMN) == "my_column"
    # Exercise get_column_name (not get_metric_name) so the verbose_map lookup
    # for physical columns is actually covered by this test.
    assert get_column_name(STR_COLUMN, {STR_COLUMN: "My Column"}) == "My Column"


def test_get_column_name_adhoc():
    """An adhoc column uses its label, then its SQL expression; verbose_map is ignored."""
    expression = "case when foo = 1 then 'foo' else 'bar' end"
    irrelevant_map = {"My Adhoc Column": "My Irrelevant Mapping"}
    adhoc = deepcopy(SQL_ADHOC_COLUMN)

    # A labelled adhoc column keeps its own label, even if the map has that key.
    assert get_column_name(adhoc) == "My Adhoc Column"
    assert get_column_name(adhoc, irrelevant_map) == "My Adhoc Column"

    # A missing label falls back to the SQL expression.
    del adhoc["label"]
    assert get_column_name(adhoc) == expression

    # An empty label behaves like a missing one.
    adhoc["label"] = ""
    assert get_column_name(adhoc) == expression


def test_get_column_names():
    """Physical columns are mapped through verbose_map; adhoc columns keep their label."""
    columns = [STR_COLUMN, SQL_ADHOC_COLUMN]
    verbose_names = {"my_column": "My Column"}

    assert get_column_names(columns) == ["my_column", "My Adhoc Column"]
    assert get_column_names(columns, verbose_names) == ["My Column", "My Adhoc Column"]


def test_get_column_name_invalid_metric():
    """
    An adhoc column with neither label nor SQL expression raises ValueError.

    Despite the "metric" in its name, this test exercises get_column_name.
    """
    broken = deepcopy(SQL_ADHOC_COLUMN)
    for key in ("label", "sqlExpression"):
        del broken[key]
    with pytest.raises(ValueError):
        get_column_name(broken)


def test_is_adhoc_metric():
Expand Down