Skip to content

Commit

Permalink
feat: add pivot v2 post-processing
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Jul 28, 2021
1 parent 3adf8e8 commit 2b2bca3
Show file tree
Hide file tree
Showing 3 changed files with 504 additions and 1 deletion.
123 changes: 122 additions & 1 deletion superset/charts/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@
from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name


def sql_like_sum(series: pd.Series) -> Any:
    """
    A SUM aggregation function that mimics the behavior from SQL.

    Unlike the pandas default (which returns 0 when summing an empty or
    all-NaN series), SQL ``SUM`` returns ``NULL`` in that case.  Passing
    ``min_count=1`` makes pandas return ``NaN`` unless at least one valid
    value is present, reproducing the SQL semantics.

    :param series: values to aggregate
    :returns: a scalar sum, or ``NaN`` when no valid values exist
        (note: a scalar, not a ``pd.Series`` — the previous annotation
        was incorrect)
    """
    return series.sum(min_count=1)


def pivot_table(
result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None
) -> Dict[Any, Any]:
Expand All @@ -53,7 +60,7 @@ def pivot_table(
aggfunc = form_data.get("pandas_aggfunc") or "sum"
if pd.api.types.is_numeric_dtype(df[metric]):
if aggfunc == "sum":
aggfunc = lambda x: x.sum(min_count=1)
aggfunc = sql_like_sum
elif aggfunc not in {"min", "max"}:
aggfunc = "max"
aggfuncs[metric] = aggfunc
Expand Down Expand Up @@ -95,6 +102,120 @@ def pivot_table(
return result


def list_unique_values(series: pd.Series) -> str:
    """
    List unique values in a series as a comma-separated string.

    Iterating over a ``set`` (as the previous implementation did) yields a
    non-deterministic order under hash randomization, making chart output
    unstable between runs.  ``Series.unique`` preserves order of first
    appearance, and ``dict.fromkeys`` (insertion-ordered) de-duplicates the
    string representations deterministically.

    :param series: values to summarize
    :returns: unique values, joined by ``", "``, in first-appearance order
    """
    return ", ".join(dict.fromkeys(str(v) for v in series.unique()))


# Map the aggregation-function names sent by the pivot-table-v2 viz plugin
# to pandas aggregation callables.  The "... as Fraction of ..." variants
# use the plain Sum/Count here; the fraction itself is computed afterwards
# in `pivot_table_v2`.
pivot_v2_aggfunc_map = {
    "Count": pd.Series.count,
    "Count Unique Values": pd.Series.nunique,
    "List Unique Values": list_unique_values,
    "Sum": pd.Series.sum,
    "Average": pd.Series.mean,
    "Median": pd.Series.median,
    # Guard the single-element case: pandas returns NaN for the sample
    # variance/std of one value, but the chart expects 0.
    # (Fixed: the original referenced the non-existent `pd.series` module
    # attribute, which would raise AttributeError at call time.)
    "Sample Variance": lambda series: pd.Series.var(series) if len(series) > 1 else 0,
    # (Fixed: a trailing comma inside the parentheses previously made this
    # value a one-element tuple rather than a callable.)
    "Sample Standard Deviation": (
        lambda series: pd.Series.std(series) if len(series) > 1 else 0
    ),
    "Minimum": pd.Series.min,
    "Maximum": pd.Series.max,
    # First/Last return a 1-element slice rather than a scalar; pandas
    # squeezes it when aggregating.
    "First": lambda series: series[:1],
    "Last": lambda series: series[-1:],
    "Sum as Fraction of Total": pd.Series.sum,
    "Sum as Fraction of Rows": pd.Series.sum,
    "Sum as Fraction of Columns": pd.Series.sum,
    "Count as Fraction of Total": pd.Series.count,
    "Count as Fraction of Rows": pd.Series.count,
    "Count as Fraction of Columns": pd.Series.count,
}


def pivot_table_v2(
    result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
) -> Dict[Any, Any]:
    """
    Pivot table v2.

    Post-process each query payload in ``result["queries"]``: pivot the
    tabular data according to the chart's form data (rows, columns,
    metrics, aggregation function), optionally compute fraction-of-total/
    row/column values, and write the reshaped rows plus column metadata
    back into the query dict.

    :param result: chart data response; mutated in place and returned
    :param form_data: chart form data (``groupbyRows``, ``groupbyColumns``,
        ``metrics``, ``aggregateFunction``, ...); defaults to ``{}``
    :returns: the same ``result`` dict, with ``data``, ``colnames``,
        ``coltypes`` and ``rowcount`` updated per query
    """
    for query in result["queries"]:
        data = query["data"]
        df = pd.DataFrame(data)
        form_data = form_data or {}

        # Drop the synthetic time column when no time grain applies.
        if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df:
            del df[DTTM_ALIAS]

        # TODO (betodealmeida): implement metricsLayout
        # NOTE(review): `form_data["metrics"]` raises KeyError when absent,
        # unlike the `.get` calls below — confirm callers always send it.
        metrics = [get_metric_name(m) for m in form_data["metrics"]]
        aggregate_function = form_data.get("aggregateFunction", "Sum")
        groupby = form_data.get("groupbyRows") or []
        columns = form_data.get("groupbyColumns") or []
        if form_data.get("transposePivot"):
            groupby, columns = columns, groupby

        df = df.pivot_table(
            index=groupby,
            columns=columns,
            values=metrics,
            aggfunc=pivot_v2_aggfunc_map[aggregate_function],
            margins=True,
        )

        # The pandas `pivot_table` method either brings both row/column
        # totals, or none at all. We pass `margins=True` to get both, and
        # remove any dimension that was not requested.
        if not form_data.get("rowTotals"):
            # The margins column ("All") is always last.
            df.drop(df.columns[len(df.columns) - 1], axis=1, inplace=True)
        if not form_data.get("colTotals"):
            # The margins row ("All") is always last.
            df = df[:-1]

        # Compute fractions, if needed. If `colTotals` or `rowTotals` are
        # present we need to adjust for including them in the sum: each
        # retained margin doubles the grand total, hence the `*= 2` below.
        if aggregate_function.endswith(" as Fraction of Total"):
            total = df.sum().sum()
            # NOTE(review): `df.sum().sum()` is a scalar here, and numpy
            # scalars expose `.dtype`, not `.dtypes` — this line looks like
            # it would raise AttributeError; confirm and fix separately.
            df = df.astype(total.dtypes) / total
            if form_data.get("colTotals"):
                df *= 2
            if form_data.get("rowTotals"):
                df *= 2
        elif aggregate_function.endswith(" as Fraction of Columns"):
            total = df.sum(axis=0)
            df = df.astype(total.dtypes).div(total, axis=1)
            if form_data.get("colTotals"):
                df *= 2
        elif aggregate_function.endswith(" as Fraction of Rows"):
            total = df.sum(axis=1)
            df = df.astype(total.dtypes).div(total, axis=0)
            if form_data.get("rowTotals"):
                df *= 2

        # Re-order the columns adhering to the metric ordering.
        df = df[metrics]

        # Display metrics side by side with each column
        if form_data.get("combineMetric"):
            df = df.stack(0).unstack().reindex(level=-1, columns=metrics)

        # flatten column names (MultiIndex tuples -> space-joined strings)
        df.columns = [" ".join(column) for column in df.columns]

        # re-arrange data into a list of dicts, one per index entry
        data = []
        for i in df.index:
            row = {col: df[col][i] for col in df.columns}
            row[df.index.name] = i
            data.append(row)
        query["data"] = data
        query["colnames"] = list(df.columns)
        query["coltypes"] = extract_dataframe_dtypes(df)
        query["rowcount"] = len(df.index)

    return result


# Registry mapping a chart's post-processing key to its handler.
post_processors = dict(
    pivot_table=pivot_table,
    pivot_table_v2=pivot_table_v2,
)
16 changes: 16 additions & 0 deletions tests/unit_tests/charts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading

0 comments on commit 2b2bca3

Please sign in to comment.