[python-package] Allow to pass Arrow array as labels (#6163)
borchero committed Nov 7, 2023
1 parent 1600422 commit b7f6311
Showing 8 changed files with 166 additions and 17 deletions.
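
In short: labels can now be passed to lightgbm.Dataset directly as a pyarrow Array or ChunkedArray, with no detour through NumPy or pandas. A minimal end-to-end sketch, assuming a build that includes this commit, pyarrow installed, and a relaxed min_data_in_bin so the toy-sized data constructs:

import lightgbm as lgb
import pyarrow as pa

data = pa.table({"feature_0": [1.0, 2.0, 3.0, 4.0]})  # features as an Arrow table
label = pa.chunked_array([[0], [1, 0, 1]])            # a plain pa.array([0, 1, 0, 1]) works too
ds = lgb.Dataset(data, label=label, params={"min_data_in_bin": 1})
ds.construct()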
17 changes: 17 additions & 0 deletions include/LightGBM/c_api.h
@@ -555,6 +555,23 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
int num_element,
int type);

/*!
* \brief Set vector content in info from a collection of Arrow arrays.
* \note
* - for \a label, the input datatype is converted into ``float32``.
* \param handle Handle of dataset
* \param field_name Field name, can be \a label
* \param n_chunks The number of Arrow arrays passed to this function
* \param chunks Pointer to the list of Arrow arrays
* \param schema Pointer to the schema of all Arrow arrays
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFieldFromArrow(DatasetHandle handle,
const char* field_name,
int64_t n_chunks,
const ArrowArray* chunks,
const ArrowSchema* schema);

/*!
* \brief Get info vector from dataset.
* \param handle Handle of dataset
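
The chunks/schema pair follows the Arrow C Data Interface. As a sketch of where these pointers come from, this is roughly how the Python side exports a single array (it mirrors _export_arrow_to_c in basic.py below; note that pyarrow's _export_to_c is a private API):

import pyarrow as pa
from pyarrow.cffi import ffi

arr = pa.array([0.0, 1.0, 2.0])
chunks = ffi.new("struct ArrowArray[]", 1)  # one struct per chunk
schema = ffi.new("struct ArrowSchema*")     # a single schema shared by all chunks
chunk_ptr = int(ffi.cast("uintptr_t", ffi.addressof(chunks[0])))
schema_ptr = int(ffi.cast("uintptr_t", schema))
arr._export_to_c(chunk_ptr, schema_ptr)
# chunk_ptr, schema_ptr and n_chunks=1 are what LGBM_DatasetSetFieldFromArrow receives.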
6 changes: 6 additions & 0 deletions include/LightGBM/dataset.h
@@ -110,6 +110,7 @@ class Metadata {
const std::vector<data_size_t>& used_data_indices);

void SetLabel(const label_t* label, data_size_t len);
void SetLabel(const ArrowChunkedArray& array);

void SetWeights(const label_t* weights, data_size_t len);

@@ -334,6 +335,9 @@ class Metadata {
void CalculateQueryBoundaries();
/*! \brief Insert labels at the given index */
void InsertLabels(const label_t* labels, data_size_t start_index, data_size_t len);
/*! \brief Set labels from pointers to the first element and the end of an iterator. */
template <typename It>
void SetLabelsFromIterator(It first, It last);
/*! \brief Insert weights at the given index */
void InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len);
/*! \brief Insert initial scores at the given index */
@@ -655,6 +659,8 @@ class Dataset {

LIGHTGBM_EXPORT void FinishLoad();

bool SetFieldFromArrow(const char* field_name, const ArrowChunkedArray& ca);

LIGHTGBM_EXPORT bool SetFloatField(const char* field_name, const float* field_data, data_size_t num_element);

LIGHTGBM_EXPORT bool SetDoubleField(const char* field_name, const double* field_data, data_size_t num_element);
45 changes: 36 additions & 9 deletions python-package/lightgbm/basic.py
@@ -19,7 +19,7 @@
import scipy.sparse

from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat,
dt_DataTable, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series)
dt_DataTable, pa_Array, pa_ChunkedArray, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series)
from .libpath import find_lib_path

if TYPE_CHECKING:
@@ -99,7 +99,9 @@
List[int],
np.ndarray,
pd_Series,
pd_DataFrame
pd_DataFrame,
pa_Array,
pa_ChunkedArray,
]
_LGBM_PredictDataType = Union[
str,
@@ -353,6 +355,11 @@ def _is_2d_collection(data: Any) -> bool:
)


def _is_pyarrow_array(data: Any) -> bool:
"""Check whether data is a PyArrow array."""
return isinstance(data, (pa_Array, pa_ChunkedArray))


def _is_pyarrow_table(data: Any) -> bool:
"""Check whether data is a PyArrow table."""
return isinstance(data, pa_Table)
@@ -384,7 +391,11 @@ def schema_ptr(self) -> int:
def _export_arrow_to_c(data: pa_Table) -> _ArrowCArray:
"""Export an Arrow type to its C representation."""
# Obtain objects to export
if isinstance(data, pa_Table):
if isinstance(data, pa_Array):
export_objects = [data]
elif isinstance(data, pa_ChunkedArray):
export_objects = data.chunks
elif isinstance(data, pa_Table):
export_objects = data.to_batches()
else:
raise ValueError(f"data of type '{type(data)}' cannot be exported to Arrow")
@@ -1620,7 +1631,7 @@ def __init__(
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence, list of numpy array or pyarrow Table
Data source of Dataset.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Label of the data.
reference : Dataset or None, optional (default=None)
If this is Dataset for validation, training data should be used as reference.
@@ -2402,7 +2413,7 @@ def create_valid(
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
Data source of Dataset.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Label of the data.
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
Weight for each instance. Weights should be non-negative.
@@ -2519,15 +2530,15 @@ def _reverse_update_params(self) -> "Dataset":
def set_field(
self,
field_name: str,
data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame]]
data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame, pa_Array, pa_ChunkedArray]]
) -> "Dataset":
"""Set property into the Dataset.
Parameters
----------
field_name : str
The field name of the information.
data : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None
data : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray or None
The data to be set.
Returns
@@ -2546,6 +2557,20 @@ def set_field(
ctypes.c_int(0),
ctypes.c_int(_FIELD_TYPE_MAPPER[field_name])))
return self

# If the data is Arrow data, we can just pass it to C
if _is_pyarrow_array(data):
c_array = _export_arrow_to_c(data)
_safe_call(_LIB.LGBM_DatasetSetFieldFromArrow(
self._handle,
_c_str(field_name),
ctypes.c_int64(c_array.n_chunks),
ctypes.c_void_p(c_array.chunks_ptr),
ctypes.c_void_p(c_array.schema_ptr),
))
self.version += 1
return self

dtype: "np.typing.DTypeLike"
if field_name == 'init_score':
dtype = np.float64
@@ -2749,7 +2774,7 @@ def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
Parameters
----------
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None
The label information to be set into Dataset.
Returns
@@ -2774,6 +2799,8 @@ def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
# data has nullable dtypes, but we can specify na_value argument and copy will be made
label = label.to_numpy(dtype=np.float32, na_value=np.nan)
label_array = np.ravel(label)
elif _is_pyarrow_array(label):
label_array = label
else:
label_array = _list_to_1d_numpy(label, dtype=np.float32, name='label')
self.set_field('label', label_array)
@@ -4353,7 +4380,7 @@ def refit(
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
Data source for refit.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
label : list, numpy 1-D array or pandas Series / one-column DataFrame
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array or pyarrow ChunkedArray
Label for refit.
decay_rate : float, optional (default=0.9)
Decay rate of refit,
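
With the changes above, set_label detects Arrow input and forwards it untouched to set_field, which calls the new LGBM_DatasetSetFieldFromArrow entry point instead of first materializing a float32 NumPy array. A quick sketch on a constructed Dataset (hypothetical toy data; min_data_in_bin relaxed as before):

import lightgbm as lgb
import pyarrow as pa

ds = lgb.Dataset(pa.table({"feature_0": [1.0, 2.0, 3.0, 4.0]}),
                 params={"min_data_in_bin": 1}).construct()
ds.set_label(pa.array([0, 1, 0, 1]))                       # plain Array
ds.set_field("label", pa.chunked_array([[0, 1], [0, 1]]))  # ChunkedArray works as well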
14 changes: 14 additions & 0 deletions python-package/lightgbm/compat.py
@@ -187,6 +187,8 @@ def __init__(self, *args, **kwargs):

"""pyarrow"""
try:
from pyarrow import Array as pa_Array
from pyarrow import ChunkedArray as pa_ChunkedArray
from pyarrow import Table as pa_Table
from pyarrow.cffi import ffi as arrow_cffi
from pyarrow.types import is_floating as arrow_is_floating
@@ -195,6 +197,18 @@ def __init__(self, *args, **kwargs):
except ImportError:
PYARROW_INSTALLED = False

class pa_Array: # type: ignore
"""Dummy class for pa.Array."""

def __init__(self, *args, **kwargs):
pass

class pa_ChunkedArray: # type: ignore
"""Dummy class for pa.ChunkedArray."""

def __init__(self, *args, **kwargs):
pass

class pa_Table: # type: ignore
"""Dummy class for pa.Table."""

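
The dummy classes are never instantiated; they only keep isinstance checks such as _is_pyarrow_array valid when pyarrow is missing, in which case those checks simply return False. A sketch of the behaviour (using the internal helper):

import numpy as np
from lightgbm.basic import _is_pyarrow_array

print(_is_pyarrow_array(np.array([1.0, 2.0])))  # False, with or without pyarrow installed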
16 changes: 16 additions & 0 deletions src/c_api.cpp
@@ -833,6 +833,7 @@ class Booster {

// explicitly declare symbols from LightGBM namespace
using LightGBM::AllgatherFunction;
using LightGBM::ArrowChunkedArray;
using LightGBM::ArrowTable;
using LightGBM::Booster;
using LightGBM::Common::CheckElementsIntervalClosed;
@@ -1780,6 +1781,21 @@ int LGBM_DatasetSetField(DatasetHandle handle,
API_END();
}

int LGBM_DatasetSetFieldFromArrow(DatasetHandle handle,
const char* field_name,
int64_t n_chunks,
const ArrowArray* chunks,
const ArrowSchema* schema) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
ArrowChunkedArray ca(n_chunks, chunks, schema);
auto is_success = dataset->SetFieldFromArrow(field_name, ca);
if (!is_success) {
Log::Fatal("Input field is not supported");
}
API_END();
}

int LGBM_DatasetGetField(DatasetHandle handle,
const char* field_name,
int* out_len,
11 changes: 11 additions & 0 deletions src/io/dataset.cpp
@@ -897,6 +897,17 @@ void Dataset::CopySubrow(const Dataset* fullset,
#endif // USE_CUDA
}

bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray &ca) {
std::string name(field_name);
name = Common::Trim(name);
if (name == std::string("label") || name == std::string("target")) {
metadata_.SetLabel(ca);
} else {
return false;
}
return true;
}

bool Dataset::SetFloatField(const char* field_name, const float* field_data,
data_size_t num_element) {
std::string name(field_name);
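
Only the label field (and its alias target) is routed through the Arrow path for now; for any other field SetFieldFromArrow returns false, which LGBM_DatasetSetFieldFromArrow escalates to Log::Fatal. From Python this surfaces as a LightGBMError, e.g. (sketch, same hypothetical toy Dataset as above):

import lightgbm as lgb
import pyarrow as pa

ds = lgb.Dataset(pa.table({"feature_0": [1.0, 2.0, 3.0, 4.0]}),
                 params={"min_data_in_bin": 1}).construct()
ds.set_field("weight", pa.array([1.0, 1.0, 1.0, 1.0]))
# lightgbm.basic.LightGBMError: Input field is not supported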
28 changes: 20 additions & 8 deletions src/io/metadata.cpp
@@ -403,27 +403,39 @@ void Metadata::InsertInitScores(const double* init_scores, data_size_t start_ind
// CUDA is handled after all insertions are complete
}

void Metadata::SetLabel(const label_t* label, data_size_t len) {
template <typename It>
void Metadata::SetLabelsFromIterator(It first, It last) {
std::lock_guard<std::mutex> lock(mutex_);
if (label == nullptr) {
Log::Fatal("label cannot be nullptr");
if (num_data_ != last - first) {
Log::Fatal("Length of labels differs from the length of #data");
}
if (num_data_ != len) {
Log::Fatal("Length of label is not same with #data");
if (label_.empty()) {
label_.resize(num_data_);
}
if (label_.empty()) { label_.resize(num_data_); }

#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (num_data_ >= 1024)
for (data_size_t i = 0; i < num_data_; ++i) {
label_[i] = Common::AvoidInf(label[i]);
label_[i] = Common::AvoidInf(first[i]);
}

#ifdef USE_CUDA
if (cuda_metadata_ != nullptr) {
cuda_metadata_->SetLabel(label_.data(), len);
cuda_metadata_->SetLabel(label_.data(), label_.size());
}
#endif // USE_CUDA
}

void Metadata::SetLabel(const label_t* label, data_size_t len) {
if (label == nullptr) {
Log::Fatal("label cannot be nullptr");
}
SetLabelsFromIterator(label, label + len);
}

void Metadata::SetLabel(const ArrowChunkedArray& array) {
SetLabelsFromIterator(array.begin<label_t>(), array.end<label_t>());
}

void Metadata::InsertLabels(const label_t* labels, data_size_t start_index, data_size_t len) {
if (labels == nullptr) {
Log::Fatal("label cannot be nullptr");
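
Both SetLabel overloads now share SetLabelsFromIterator, which writes each element through Common::AvoidInf into label_, a vector of label_t (float32 by default). Any supported Arrow numeric type therefore ends up as float32, as the tests below assert. Sketch (same hypothetical toy table):

import numpy as np
import pyarrow as pa
import lightgbm as lgb

data = pa.table({"feature_0": [1.0, 2.0, 3.0, 4.0]})
ds = lgb.Dataset(data, label=pa.array([0, 1, 0, 1], type=pa.int64()),
                 params={"min_data_in_bin": 1})
ds.construct()
assert ds.get_label().dtype == np.float32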
46 changes: 46 additions & 0 deletions tests/python_package_test/test_arrow.py
@@ -67,6 +67,10 @@ def dummy_dataset_params() -> Dict[str, Any]:
}


def assert_arrays_equal(lhs: np.ndarray, rhs: np.ndarray):
assert lhs.dtype == rhs.dtype and np.array_equal(lhs, rhs)


# ----------------------------------------------------------------------------------------------- #
# UNIT TESTS #
# ----------------------------------------------------------------------------------------------- #
@@ -97,3 +101,45 @@ def test_dataset_construct_fuzzy(
arrow_dataset._dump_text(tmp_path / "arrow.txt")
pandas_dataset._dump_text(tmp_path / "pandas.txt")
assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt")


@pytest.mark.parametrize(
["array_type", "label_data"],
[(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
)
@pytest.mark.parametrize(
"arrow_type",
[
pa.int8(),
pa.int16(),
pa.int32(),
pa.int64(),
pa.uint8(),
pa.uint16(),
pa.uint32(),
pa.uint64(),
pa.float32(),
pa.float64(),
],
)
def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: Any):
data = generate_dummy_arrow_table()
labels = array_type(label_data, type=arrow_type)
dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params())
dataset.construct()

expected = np.array([0, 1, 0, 0, 1], dtype=np.float32)
assert_arrays_equal(expected, dataset.get_label())


def test_dataset_construct_labels_fuzzy():
arrow_table = generate_random_arrow_table(3, 1000, 42)
arrow_array = generate_random_arrow_array(1000, 42)

arrow_dataset = lgb.Dataset(arrow_table, label=arrow_array)
arrow_dataset.construct()

pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), label=arrow_array.to_numpy())
pandas_dataset.construct()

assert_arrays_equal(arrow_dataset.get_label(), pandas_dataset.get_label())
