Skip to content

Commit

Permalink
feat: support non-numeric columns in pivot table (apache#10389)
Browse files Browse the repository at this point in the history
* fix: support non-numeric columns in pivot table

* bump package and add unit tests

* mypy
  • Loading branch information
villebro committed Jul 28, 2020
1 parent c8a0159 commit 88bc79a
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 8 deletions.
39 changes: 31 additions & 8 deletions superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,18 @@
from collections import defaultdict, OrderedDict
from datetime import datetime, timedelta
from itertools import product
from typing import Any, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)

import dataclasses
import geohash
Expand Down Expand Up @@ -736,6 +747,7 @@ class PivotTableViz(BaseViz):
verbose_name = _("Pivot Table")
credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
is_timeseries = False
enforce_numerical_metrics = False

def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
Expand Down Expand Up @@ -766,29 +778,40 @@ def query_obj(self) -> QueryObjectDict:
raise QueryObjectValidationError(_("Group By' and 'Columns' can't overlap"))
return d

@staticmethod
def get_aggfunc(
    metric: str, df: pd.DataFrame, form_data: Dict[str, Any]
) -> Union[str, Callable[[Any], Any]]:
    """Pick the pandas aggregation to apply to one metric column.

    :param metric: name of the column in ``df`` that will be aggregated
    :param df: dataframe whose ``metric`` column dtype decides which
        aggregations are allowed
    :param form_data: chart form data; ``pandas_aggfunc`` is read from it and
        falls back to ``"sum"`` when the key is missing or falsy
    :returns: the name of an aggregation function, or a callable in the
        numeric ``"sum"`` case
    :raises KeyError: if ``metric`` is not a column of ``df``
    """
    aggfunc = form_data.get("pandas_aggfunc") or "sum"
    if pd.api.types.is_numeric_dtype(df[metric]):
        # Ensure that Pandas's sum function mimics that of SQL: min_count=1
        # makes a group without any valid value sum to NaN instead of 0.
        if aggfunc == "sum":
            return lambda x: x.sum(min_count=1)
        # Any other aggregation (mean, median, std, ...) is valid for numeric
        # data, so the requested one must not be replaced by the non-numeric
        # fallback below.
        return aggfunc
    # only min and max work properly for non-numerics
    return aggfunc if aggfunc in ("min", "max") else "max"

def get_data(self, df: pd.DataFrame) -> VizData:
if df.empty:
return None

if self.form_data.get("granularity") == "all" and DTTM_ALIAS in df:
del df[DTTM_ALIAS]

aggfunc = self.form_data.get("pandas_aggfunc") or "sum"

# Ensure that Pandas's sum function mimics that of SQL.
if aggfunc == "sum":
aggfunc = lambda x: x.sum(min_count=1)
metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]
aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {}
for metric in metrics:
aggfuncs[metric] = self.get_aggfunc(metric, df, self.form_data)

groupby = self.form_data.get("groupby")
columns = self.form_data.get("columns")
if self.form_data.get("transpose_pivot"):
groupby, columns = columns, groupby
metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]

df = df.pivot_table(
index=groupby,
columns=columns,
values=metrics,
aggfunc=aggfunc,
aggfunc=aggfuncs,
margins=self.form_data.get("pivot_margins"),
)

Expand Down
38 changes: 38 additions & 0 deletions tests/viz_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,3 +1284,41 @@ def test_get_data_with_none(self):
)
data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df)
assert np.isnan(data[2]["y"])


class TestPivotTableViz(SupersetTestCase):
    """Checks which aggregation ``viz.PivotTableViz.get_aggfunc`` selects."""

    # Fixture shared by the tests: two numeric columns and one string column,
    # each ending with a missing value.
    df = pd.DataFrame(
        data={
            "intcol": [1, 2, 3, None],
            "floatcol": [0.1, 0.2, 0.3, None],
            "strcol": ["a", "b", "c", None],
        }
    )

    def test_get_aggfunc_numeric(self):
        """Numeric columns: default is a sum callable, min/max pass through."""
        pick = viz.PivotTableViz.get_aggfunc
        frame = self.df

        # without an explicit choice the result is a sum function
        default_agg = pick("intcol", frame, {})
        assert callable(default_agg)
        assert default_agg(frame["intcol"]) == 6

        int_min = pick("intcol", frame, {"pandas_aggfunc": "min"})
        assert int_min == "min"

        float_max = pick("floatcol", frame, {"pandas_aggfunc": "max"})
        assert float_max == "max"

    def test_get_aggfunc_non_numeric(self):
        """String columns: only min is kept, everything else becomes max."""
        cases = (
            ({}, "max"),
            ({"pandas_aggfunc": "sum"}, "max"),
            ({"pandas_aggfunc": "min"}, "min"),
        )
        for form_data, expected in cases:
            chosen = viz.PivotTableViz.get_aggfunc("strcol", self.df, form_data)
            assert chosen == expected

0 comments on commit 88bc79a

Please sign in to comment.