[SPARK-43295][PS] Support string type columns for DataFrameGroupBy.sum
### What changes were proposed in this pull request?

This PR proposes to support string type columns for `DataFrameGroupBy.sum`.

### Why are the changes needed?

To match the behavior of the latest pandas.

### Does this PR introduce _any_ user-facing change?

Yes, from now on `DataFrameGroupBy.sum` follows the behavior of the latest pandas, as shown below:

**Test DataFrame**
```python
>>> psdf
   A    B  C      D
0  1  3.1  a   True
1  2  4.1  b  False
2  1  4.1  b  False
3  2  3.1  a   True
```
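For reference, a minimal sketch of how this test frame can be constructed (the values are simply read off the output above):

```python
import pyspark.pandas as ps

# Values inferred from the printed frame above.
psdf = ps.DataFrame({
    "A": [1, 2, 1, 2],
    "B": [3.1, 4.1, 4.1, 3.1],
    "C": ["a", "b", "b", "a"],
    "D": [True, False, False, True],
})
```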

**Before**
```python
>>> psdf.groupby("A").sum().sort_index()
     B  D
A
1  7.2  1
2  7.2  1
```

**After**
```python
>>> psdf.groupby("A").sum().sort_index()
     B   C  D
A
1  7.2  ab  1
2  7.2  ba  1
```
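For comparison, recent pandas (2.x, where string columns are concatenated by `sum`) returns the same result for this frame, which is the behavior being matched:

```python
>>> pdf = psdf.to_pandas()
>>> pdf.groupby("A").sum().sort_index()
     B   C  D
A
1  7.2  ab  1
2  7.2  ba  1
```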

### How was this patch tested?

Updated the existing unit tests to cover string type columns.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#42798 from itholic/SPARK-43295.

Authored-by: Haejoon Lee <haejoon.lee@databricks.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
Committed on Sep 11, 2023
1 parent eb0b09f commit 3d119a5
Showing 4 changed files with 28 additions and 18 deletions.
32 changes: 24 additions & 8 deletions python/pyspark/pandas/groupby.py
```diff
@@ -857,10 +857,10 @@ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameLike:
         ...     "C": [3, 4, 3, 4], "D": ["a", "a", "b", "a"]})
 
         >>> df.groupby("A").sum().sort_index()
-           B  C
+           B  C   D
         A
-        1  1  6
-        2  1  8
+        1  1  6  ab
+        2  1  8  aa
 
         >>> df.groupby("D").sum().sort_index()
            A  B  C
```
```diff
@@ -900,17 +900,17 @@ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameLike:
         unsupported = [
             col.name
             for col in self._agg_columns
-            if not isinstance(col.spark.data_type, (NumericType, BooleanType))
+            if not isinstance(col.spark.data_type, (NumericType, BooleanType, StringType))
         ]
         if len(unsupported) > 0:
             log_advice(
-                "GroupBy.sum() can only support numeric and bool columns even if"
+                "GroupBy.sum() can only support numeric, bool and string columns even if"
                 f"numeric_only=False, skip unsupported columns: {unsupported}"
             )
 
         return self._reduce_for_stat_function(
             F.sum,
-            accepted_spark_types=(NumericType, BooleanType),
+            accepted_spark_types=(NumericType, BooleanType, StringType),
             bool_to_numeric=True,
             min_count=min_count,
         )
```
```diff
@@ -3534,7 +3534,21 @@ def _reduce_for_stat_function(
         for label in psdf._internal.column_labels:
             psser = psdf._psser_for(label)
             input_scol = psser._dtype_op.nan_to_null(psser).spark.column
-            output_scol = sfun(input_scol)
+            if sfun.__name__ == "sum" and isinstance(
+                psdf._internal.spark_type_for(label), StringType
+            ):
+                input_scol_name = psser._internal.data_spark_column_names[0]
+                # Sort data with natural order column to ensure order of data
+                sorted_array = F.array_sort(
+                    F.collect_list(F.struct(NATURAL_ORDER_COLUMN_NAME, input_scol))
+                )
+
+                # Using transform to extract strings
+                output_scol = F.concat_ws(
+                    "", F.transform(sorted_array, lambda x: x.getField(input_scol_name))
+                )
+            else:
+                output_scol = sfun(input_scol)
 
             if min_count > 0:
                 output_scol = F.when(
```
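The branch above is the core of the change: instead of `F.sum`, each string column is collected per group together with the frame's internal natural-order column, sorted by that order, and concatenated. A standalone sketch of the same pattern in plain PySpark (the names `key`, `__order__`, and `v` are illustrative, not the internal column names):

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame(
    [(1, 0, "a"), (2, 1, "b"), (1, 2, "b"), (2, 3, "a")],
    ["key", "__order__", "v"],
)

# Pair each value with its row-order index, sort the collected array by that
# index (array_sort orders structs by their first field), then concatenate.
sorted_array = F.array_sort(F.collect_list(F.struct("__order__", "v")))
result = sdf.groupBy("key").agg(
    F.concat_ws("", F.transform(sorted_array, lambda x: x.getField("v"))).alias("v_sum")
)
result.show()  # key=1 -> "ab", key=2 -> "ba"
```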
```diff
@@ -3591,7 +3605,9 @@ def _prepare_reduce(
         ):
             agg_columns.append(psser)
         sdf = self._psdf._internal.spark_frame.select(
-            *groupkey_scols, *[psser.spark.column for psser in agg_columns]
+            *groupkey_scols,
+            *[psser.spark.column for psser in agg_columns],
+            NATURAL_ORDER_COLUMN_NAME,
         )
         internal = InternalFrame(
             spark_frame=sdf,
```
6 changes: 0 additions & 6 deletions python/pyspark/pandas/tests/groupby/test_groupby.py
```diff
@@ -59,9 +59,6 @@ def test_groupby_simple(self):
             },
             index=[0, 1, 3, 5, 6, 8, 9, 9, 9],
         )
-        if LooseVersion(pd.__version__) >= LooseVersion("2.0.0"):
-            # TODO(SPARK-43295): Make DataFrameGroupBy.sum support for string type columns
-            pdf = pdf[["a", "b", "c", "e"]]
         psdf = ps.from_pandas(pdf)
 
         for as_index in [True, False]:
@@ -180,9 +177,6 @@ def sort(df):
             index=[0, 1, 3, 5, 6, 8, 9, 9, 9],
         )
         psdf = ps.from_pandas(pdf)
-        if LooseVersion(pd.__version__) >= LooseVersion("2.0.0"):
-            # TODO(SPARK-43295): Make DataFrameGroupBy.sum support for string type columns
-            pdf = pdf[[10, 20, 30]]
 
         for as_index in [True, False]:
             if as_index:
```
2 changes: 1 addition & 1 deletion python/pyspark/pandas/tests/groupby/test_stat.py
```diff
@@ -113,7 +113,7 @@ def test_basic_stat_funcs(self):
         # self._test_stat_func(lambda groupby_obj: groupby_obj.sum(), check_exact=False)
         self.assert_eq(
             psdf.groupby("A").sum().sort_index(),
-            pdf.groupby("A").sum(numeric_only=True).sort_index(),
+            pdf.groupby("A").sum().sort_index(),
             check_exact=False,
         )
 
```
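With string columns supported, the pandas side of the comparison no longer needs the `numeric_only=True` workaround. A sketch of the equivalence the updated test asserts, assuming pandas >= 2.0:

```python
import pandas as pd
import pyspark.pandas as ps
from pandas.testing import assert_frame_equal

pdf = pd.DataFrame({"A": [1, 2, 1, 2], "B": [3.1, 4.1, 4.1, 3.1], "C": ["a", "b", "b", "a"]})
psdf = ps.from_pandas(pdf)

# Same comparison as the updated test, without the numeric_only workaround.
assert_frame_equal(
    psdf.groupby("A").sum().sort_index().to_pandas(),
    pdf.groupby("A").sum().sort_index(),
    check_exact=False,
)
```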
6 changes: 3 additions & 3 deletions
```diff
@@ -66,7 +66,7 @@ def sort(df):
 
         self.assert_eq(
             sort(psdf1.groupby(psdf2.a, as_index=as_index).sum()),
-            sort(pdf1.groupby(pdf2.a, as_index=as_index).sum(numeric_only=True)),
+            sort(pdf1.groupby(pdf2.a, as_index=as_index).sum()),
             almost=as_index,
         )
 
@@ -93,7 +93,7 @@ def test_groupby_multiindex_columns(self):
 
         self.assert_eq(
             psdf1.groupby(psdf2[("x", "a")]).sum().sort_index(),
-            pdf1.groupby(pdf2[("x", "a")]).sum(numeric_only=True).sort_index(),
+            pdf1.groupby(pdf2[("x", "a")]).sum().sort_index(),
         )
 
         self.assert_eq(
@@ -102,7 +102,7 @@
             .sort_values(("y", "c"))
             .reset_index(drop=True),
             pdf1.groupby(pdf2[("x", "a")], as_index=False)
-            .sum(numeric_only=True)
+            .sum()
             .sort_values(("y", "c"))
             .reset_index(drop=True),
         )
```
