Skip to content

Commit

Permalink
Implement numeric_only=False for GroupBy.corr and ``GroupBy.cov`` (dask#10264)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl committed May 8, 2023
1 parent 980fd92 commit e9845aa
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
9 changes: 5 additions & 4 deletions dask/dataframe/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
GROUP_KEYS_DEFAULT,
DataFrame,
Series,
_convert_to_numeric,
_extract_meta,
_Frame,
aca,
Expand Down Expand Up @@ -90,8 +91,6 @@
)

NUMERIC_ONLY_NOT_IMPLEMENTED = [
"corr",
"cov",
"cumprod",
"cumsum",
"mean",
Expand Down Expand Up @@ -667,6 +666,10 @@ def _cov_chunk(df, *by, numeric_only=no_default):
if is_series_like(df):
df = df.to_frame()
df = df.copy()
if numeric_only is False:
dt_df = df.select_dtypes(include=["datetime", "timedelta"])
for col in dt_df.columns:
df[col] = _convert_to_numeric(dt_df[col], True)

# mapping columns to str(numerical) values allows us to easily handle
# arbitrary column names (numbers, string, empty strings)
Expand Down Expand Up @@ -2101,7 +2104,6 @@ def std(self, ddof=1, split_every=None, split_out=1, numeric_only=no_default):
return result

@derived_from(pd.DataFrame)
@numeric_only_not_implemented
def corr(self, ddof=1, split_every=None, split_out=1, numeric_only=no_default):
"""Groupby correlation:
corr(X, Y) = cov(X, Y) / (std_x * std_y)
Expand All @@ -2116,7 +2118,6 @@ def corr(self, ddof=1, split_every=None, split_out=1, numeric_only=no_default):
)

@derived_from(pd.DataFrame)
@numeric_only_not_implemented
def cov(
self, ddof=1, split_every=None, split_out=1, std=False, numeric_only=no_default
):
Expand Down
21 changes: 21 additions & 0 deletions dask/dataframe/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -3634,3 +3634,24 @@ def test_groupby_numeric_only_true(func):
ddf_result = getattr(ddf.groupby("A"), func)(numeric_only=True)
pdf_result = getattr(df.groupby("A"), func)(numeric_only=True)
assert_eq(ddf_result, pdf_result)


@pytest.mark.skipif(not PANDAS_GT_150, reason="numeric_only not supported for <1.5")
@pytest.mark.parametrize("func", ["cov", "corr"])
def test_groupby_numeric_only_false_cov_corr(func):
    """Compare dask and pandas groupby ``cov``/``corr`` for both ``numeric_only`` values.

    The frame holds a float column, an int column and a timedelta column,
    plus a constant key column ``"A"`` so every row lands in one group.
    The dask result (built from a two-partition collection) must equal the
    pandas result, first with ``numeric_only=False`` and then with
    ``numeric_only=True``.
    """
    pdf = pd.DataFrame(
        {
            "float": [1.0, 2.0, 3.0, 4.0, 5, 6.0, 7.0, 8.0],
            "int": [1, 2, 3, 4, 5, 6, 7, 8],
            "timedelta": pd.to_timedelta([1, 2, 3, 4, 5, 6, 7, 8]),
            "A": 1,
        }
    )
    ddf = dd.from_pandas(pdf, npartitions=2)

    # Same order as before: the non-numeric-inclusive path first, then the
    # numeric-only path.
    for flag in (False, True):
        actual = getattr(ddf.groupby("A"), func)(numeric_only=flag)
        expected = getattr(pdf.groupby("A"), func)(numeric_only=flag)
        assert_eq(actual, expected)

0 comments on commit e9845aa

Please sign in to comment.