From 9b51ab2ae6d16b4d599b99fd253aba00c959b41d Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Sun, 10 Dec 2023 20:11:13 +0100 Subject: [PATCH] ENH: Make get_dummies return ea booleans for ea inputs (#56291) * ENH: Make get_dummies return ea booleans for ea inputs * ENH: Make get_dummies return ea booleans for ea inputs * Update * Update pandas/tests/reshape/test_get_dummies.py Co-authored-by: Thomas Baumann * Update test_get_dummies.py * Update test_get_dummies.py * Fixup --------- Co-authored-by: Thomas Baumann --- doc/source/whatsnew/v2.2.0.rst | 1 + pandas/core/reshape/encoding.py | 24 ++++++++++++- pandas/tests/reshape/test_get_dummies.py | 45 ++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index ad44e87cacf82..acec8379ee5b3 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -218,6 +218,7 @@ Other enhancements - :meth:`~DataFrame.to_sql` with method parameter set to ``multi`` works with Oracle on the backend - :attr:`Series.attrs` / :attr:`DataFrame.attrs` now uses a deepcopy for propagating ``attrs`` (:issue:`54134`). +- :func:`get_dummies` now returning extension dtypes ``boolean`` or ``bool[pyarrow]`` that are compatible with the input dtype (:issue:`56273`) - :func:`read_csv` now supports ``on_bad_lines`` parameter with ``engine="pyarrow"``. (:issue:`54480`) - :func:`read_sas` returns ``datetime64`` dtypes with resolutions better matching those stored natively in SAS, and avoids returning object-dtype in cases that cannot be stored with ``datetime64[ns]`` dtype (:issue:`56127`) - :func:`read_spss` now returns a :class:`DataFrame` that stores the metadata in :attr:`DataFrame.attrs`. (:issue:`54264`) diff --git a/pandas/core/reshape/encoding.py b/pandas/core/reshape/encoding.py index 6963bf677bcfb..3ed67bb7b7c02 100644 --- a/pandas/core/reshape/encoding.py +++ b/pandas/core/reshape/encoding.py @@ -21,9 +21,14 @@ is_object_dtype, pandas_dtype, ) +from pandas.core.dtypes.dtypes import ( + ArrowDtype, + CategoricalDtype, +) from pandas.core.arrays import SparseArray from pandas.core.arrays.categorical import factorize_from_iterable +from pandas.core.arrays.string_ import StringDtype from pandas.core.frame import DataFrame from pandas.core.indexes.api import ( Index, @@ -244,8 +249,25 @@ def _get_dummies_1d( # Series avoids inconsistent NaN handling codes, levels = factorize_from_iterable(Series(data, copy=False)) - if dtype is None: + if dtype is None and hasattr(data, "dtype"): + input_dtype = data.dtype + if isinstance(input_dtype, CategoricalDtype): + input_dtype = input_dtype.categories.dtype + + if isinstance(input_dtype, ArrowDtype): + import pyarrow as pa + + dtype = ArrowDtype(pa.bool_()) # type: ignore[assignment] + elif ( + isinstance(input_dtype, StringDtype) + and input_dtype.storage != "pyarrow_numpy" + ): + dtype = pandas_dtype("boolean") # type: ignore[assignment] + else: + dtype = np.dtype(bool) + elif dtype is None: dtype = np.dtype(bool) + _dtype = pandas_dtype(dtype) if is_object_dtype(_dtype): diff --git a/pandas/tests/reshape/test_get_dummies.py b/pandas/tests/reshape/test_get_dummies.py index 3bfff56cfedf2..9b7aefac60969 100644 --- a/pandas/tests/reshape/test_get_dummies.py +++ b/pandas/tests/reshape/test_get_dummies.py @@ -4,13 +4,18 @@ import numpy as np import pytest +import pandas.util._test_decorators as td + from pandas.core.dtypes.common import is_integer_dtype import pandas as pd from pandas import ( + ArrowDtype, Categorical, + CategoricalDtype, CategoricalIndex, DataFrame, + Index, RangeIndex, Series, SparseDtype, @@ -19,6 +24,11 @@ import pandas._testing as tm from pandas.core.arrays.sparse import SparseArray +try: + import pyarrow as pa +except ImportError: + pa = None + class TestGetDummies: @pytest.fixture @@ -217,6 +227,7 @@ def test_dataframe_dummies_string_dtype(self, df): }, dtype=bool, ) + expected[["B_b", "B_c"]] = expected[["B_b", "B_c"]].astype("boolean") tm.assert_frame_equal(result, expected) def test_dataframe_dummies_mix_default(self, df, sparse, dtype): @@ -693,3 +704,37 @@ def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype): dtype=any_numeric_ea_and_arrow_dtype, ) tm.assert_frame_equal(result, expected) + + @td.skip_if_no("pyarrow") + def test_get_dummies_ea_dtype(self): + # GH#56273 + for dtype, exp_dtype in [ + ("string[pyarrow]", "boolean"), + ("string[pyarrow_numpy]", "bool"), + (CategoricalDtype(Index(["a"], dtype="string[pyarrow]")), "boolean"), + (CategoricalDtype(Index(["a"], dtype="string[pyarrow_numpy]")), "bool"), + ]: + df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1}) + result = get_dummies(df) + expected = DataFrame({"x": 1, "name_a": Series([True], dtype=exp_dtype)}) + tm.assert_frame_equal(result, expected) + + @td.skip_if_no("pyarrow") + def test_get_dummies_arrow_dtype(self): + # GH#56273 + df = DataFrame({"name": Series(["a"], dtype=ArrowDtype(pa.string())), "x": 1}) + result = get_dummies(df) + expected = DataFrame({"x": 1, "name_a": Series([True], dtype="bool[pyarrow]")}) + tm.assert_frame_equal(result, expected) + + df = DataFrame( + { + "name": Series( + ["a"], + dtype=CategoricalDtype(Index(["a"], dtype=ArrowDtype(pa.string()))), + ), + "x": 1, + } + ) + result = get_dummies(df) + tm.assert_frame_equal(result, expected)