Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(chart-data-api): make pivoted columns flattenable #10255

Merged
merged 2 commits into from
Jul 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 0 additions & 2 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,6 @@ class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema)
fields.String(
allow_none=False, description="Columns to group by on the table columns",
),
minLength=1,
required=True,
)
metric_fill_value = fields.Number(
description="Value to replace missing values with in aggregate calculations.",
Expand Down
42 changes: 36 additions & 6 deletions superset/utils/pandas_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,38 @@
)


def _flatten_column_after_pivot(
column: Union[str, Tuple[str, ...]], aggregates: Dict[str, Dict[str, Any]]
) -> str:
"""
Function for flattening column names into a single string. This step is necessary
to be able to properly serialize a DataFrame. If the column is a string, return
element unchanged. For multi-element columns, join column elements with a comma,
with the exception of pivots made with a single aggregate, in which case the
aggregate column name is omitted.

:param column: single element from `DataFrame.columns`
:param aggregates: aggregates
:return:
"""
if isinstance(column, str):
return column
if len(column) == 1:
return column[0]
if len(aggregates) == 1 and len(column) > 1:
# drop aggregate for single aggregate pivots with multiple groupings
# from column name (aggregates always come first in column name)
column = column[1:]
return ", ".join(column)


def validate_column_args(*argnames: str) -> Callable[..., Any]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(df: DataFrame, **options: Any) -> Any:
columns = df.columns.tolist()
for name in argnames:
if name in options and not all(
elem in columns for elem in options[name]
elem in columns for elem in options.get(name) or []
):
raise QueryObjectValidationError(
_("Referenced columns not available in DataFrame.")
Expand Down Expand Up @@ -154,14 +179,15 @@ def _append_columns(
def pivot( # pylint: disable=too-many-arguments
df: DataFrame,
index: List[str],
columns: List[str],
aggregates: Dict[str, Dict[str, Any]],
columns: Optional[List[str]] = None,
metric_fill_value: Optional[Any] = None,
column_fill_value: Optional[str] = None,
drop_missing_columns: Optional[bool] = True,
combine_value_with_metric: bool = False,
marginal_distributions: Optional[bool] = None,
marginal_distribution_name: Optional[str] = None,
flatten_columns: bool = True,
) -> DataFrame:
"""
Perform a pivot operation on a DataFrame.
Expand All @@ -179,17 +205,14 @@ def pivot( # pylint: disable=too-many-arguments
:param marginal_distributions: Add totals for row/column. Default to False
:param marginal_distribution_name: Name of row/column with marginal distribution.
Default to 'All'.
:param flatten_columns: Convert column names to strings
:return: A pivot table
:raises ChartDataValidationError: If the request in incorrect
"""
if not index:
raise QueryObjectValidationError(
_("Pivot operation requires at least one index")
)
if not columns:
raise QueryObjectValidationError(
_("Pivot operation requires at least one column")
)
if not aggregates:
raise QueryObjectValidationError(
_("Pivot operation must include at least one aggregate")
Expand Down Expand Up @@ -218,6 +241,13 @@ def pivot( # pylint: disable=too-many-arguments
if combine_value_with_metric:
df = df.stack(0).unstack()

# Make index regular column
if flatten_columns:
df.columns = [
_flatten_column_after_pivot(col, aggregates) for col in df.columns
]
# return index as regular column
df.reset_index(level=0, inplace=True)
return df


Expand Down
111 changes: 98 additions & 13 deletions tests/pandas_postprocessing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
from .base_tests import SupersetTestCase
from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df

AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}}
AGGREGATES_MULTIPLE = {
"idx_nulls": {"operator": "sum"},
"asc_idx": {"operator": "mean"},
}


def series_to_list(series: Series) -> List[Any]:
"""
Expand Down Expand Up @@ -57,41 +63,120 @@ def round_floats(


class TestPostProcessing(SupersetTestCase):
def test_pivot(self):
aggregates = {"idx_nulls": {"operator": "sum"}}
def test_flatten_column_after_pivot(self):
"""
Test pivot column flattening function
"""
# single aggregate cases
self.assertEqual(
proc._flatten_column_after_pivot(
aggregates=AGGREGATES_SINGLE, column="idx_nulls",
),
"idx_nulls",
)
self.assertEqual(
proc._flatten_column_after_pivot(
aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1"),
),
"col1",
)
self.assertEqual(
proc._flatten_column_after_pivot(
aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1", "col2"),
),
"col1, col2",
)

# Multiple aggregate cases
self.assertEqual(
proc._flatten_column_after_pivot(
aggregates=AGGREGATES_MULTIPLE, column=("idx_nulls", "asc_idx", "col1"),
),
"idx_nulls, asc_idx, col1",
)
self.assertEqual(
proc._flatten_column_after_pivot(
aggregates=AGGREGATES_MULTIPLE,
column=("idx_nulls", "asc_idx", "col1", "col2"),
),
"idx_nulls, asc_idx, col1, col2",
)

def test_pivot_without_columns(self):
"""
Make sure pivot without columns returns correct DataFrame
"""
df = proc.pivot(df=categories_df, index=["name"], aggregates=AGGREGATES_SINGLE,)
self.assertListEqual(
df.columns.tolist(), ["name", "idx_nulls"],
)
self.assertEqual(len(df), 101)
self.assertEqual(df.sum()[1], 1050)

# regular pivot
def test_pivot_with_single_column(self):
"""
Make sure pivot with single column returns correct DataFrame
"""
df = proc.pivot(
df=categories_df,
index=["name"],
columns=["category"],
aggregates=aggregates,
aggregates=AGGREGATES_SINGLE,
)
self.assertListEqual(
df.columns.tolist(),
[("idx_nulls", "cat0"), ("idx_nulls", "cat1"), ("idx_nulls", "cat2")],
df.columns.tolist(), ["name", "cat0", "cat1", "cat2"],
)
self.assertEqual(len(df), 101)
self.assertEqual(df.sum()[0], 315)
self.assertEqual(df.sum()[1], 315)

# regular pivot
df = proc.pivot(
df=categories_df,
index=["dept"],
columns=["category"],
aggregates=aggregates,
aggregates=AGGREGATES_SINGLE,
)
self.assertListEqual(
df.columns.tolist(), ["dept", "cat0", "cat1", "cat2"],
)
self.assertEqual(len(df), 5)

# fill value
def test_pivot_with_multiple_columns(self):
"""
Make sure pivot with multiple columns returns correct DataFrame
"""
df = proc.pivot(
df=categories_df,
index=["name"],
columns=["category", "dept"],
aggregates=AGGREGATES_SINGLE,
)
self.assertEqual(len(df.columns), 1 + 3 * 5) # index + possible permutations

def test_pivot_fill_values(self):
"""
Make sure pivot with fill values returns correct DataFrame
"""
df = proc.pivot(
df=categories_df,
index=["name"],
columns=["category"],
metric_fill_value=1,
aggregates={"idx_nulls": {"operator": "sum"}},
)
self.assertEqual(df.sum()[0], 382)
self.assertEqual(df.sum()[1], 382)

def test_pivot_exceptions(self):
"""
Make sure pivot raises correct Exceptions
"""
# Missing index
self.assertRaises(
TypeError,
proc.pivot,
df=categories_df,
columns=["dept"],
aggregates=AGGREGATES_SINGLE,
)

# invalid index reference
self.assertRaises(
Expand All @@ -100,7 +185,7 @@ def test_pivot(self):
df=categories_df,
index=["abc"],
columns=["dept"],
aggregates=aggregates,
aggregates=AGGREGATES_SINGLE,
)

# invalid column reference
Expand All @@ -110,7 +195,7 @@ def test_pivot(self):
df=categories_df,
index=["dept"],
columns=["abc"],
aggregates=aggregates,
aggregates=AGGREGATES_SINGLE,
)

# invalid aggregate options
Expand Down