Skip to content

Commit

Permalink
[python-package] consolidate pandas-to-numpy conversion code (#6156)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Nov 16, 2023
1 parent e63e54a commit 18dbd65
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,23 @@ def _check_for_bad_pandas_dtypes(pandas_dtypes_series: pd_Series) -> None:
f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}')


def _pandas_to_numpy(
data: pd_DataFrame,
target_dtype: "np.typing.DTypeLike"
) -> np.ndarray:
_check_for_bad_pandas_dtypes(data.dtypes)
try:
# most common case (no nullable dtypes)
return data.to_numpy(dtype=target_dtype, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
return data.astype(target_dtype, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
return data.to_numpy(dtype=target_dtype, na_value=np.nan)


def _data_from_pandas(
data: pd_DataFrame,
feature_name: _LGBM_FeatureNameConfiguration,
Expand Down Expand Up @@ -790,22 +807,17 @@ def _data_from_pandas(
else: # use cat cols specified by user
categorical_feature = list(categorical_feature) # type: ignore[assignment]

# get numpy representation of the data
_check_for_bad_pandas_dtypes(data.dtypes)
df_dtypes = [dtype.type for dtype in data.dtypes]
df_dtypes.append(np.float32) # so that the target dtype considers floats
# so that the target dtype considers floats
df_dtypes.append(np.float32)
target_dtype = np.result_type(*df_dtypes)
try:
# most common case (no nullable dtypes)
data = data.to_numpy(dtype=target_dtype, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
data = data.astype(target_dtype, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
data = data.to_numpy(dtype=target_dtype, na_value=np.nan)
return data, feature_name, categorical_feature, pandas_categorical

return (
_pandas_to_numpy(data, target_dtype=target_dtype),
feature_name,
categorical_feature,
pandas_categorical
)


def _dump_pandas_categorical(
Expand Down Expand Up @@ -2805,18 +2817,7 @@ def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
if isinstance(label, pd_DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
_check_for_bad_pandas_dtypes(label.dtypes)
try:
# most common case (no nullable dtypes)
label = label.to_numpy(dtype=np.float32, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
label = label.astype(np.float32, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
label = label.to_numpy(dtype=np.float32, na_value=np.nan)
label_array = np.ravel(label)
label_array = np.ravel(_pandas_to_numpy(label, target_dtype=np.float32))
elif _is_pyarrow_array(label):
label_array = label
else:
Expand Down

0 comments on commit 18dbd65

Please sign in to comment.