Skip to content

Commit

Permalink
[python-package] Allow to pass Arrow array as weights (#6164)
Browse files Browse the repository at this point in the history
  • Loading branch information
borchero committed Nov 13, 2023
1 parent 501e6e6 commit deb7077
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 32 deletions.
4 changes: 2 additions & 2 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
/*!
* \brief Set vector to a content in info.
* \note
* - \a label convert input datatype into ``float32``.
* - \a label and \a weight convert input datatype into ``float32``.
* \param handle Handle of dataset
* \param field_name Field name, can be \a label
* \param field_name Field name, can be \a label, \a weight
* \param n_chunks The number of Arrow arrays passed to this function
* \param chunks Pointer to the list of Arrow arrays
* \param schema Pointer to the schema of all Arrow arrays
Expand Down
4 changes: 4 additions & 0 deletions include/LightGBM/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class Metadata {
void SetLabel(const ArrowChunkedArray& array);

void SetWeights(const label_t* weights, data_size_t len);
void SetWeights(const ArrowChunkedArray& array);

void SetQuery(const data_size_t* query, data_size_t len);

Expand Down Expand Up @@ -340,6 +341,9 @@ class Metadata {
void SetLabelsFromIterator(It first, It last);
/*! \brief Insert weights at the given index */
void InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len);
/*! \brief Set weights from pointers to the first element and the end of an iterator. */
template <typename It>
void SetWeightsFromIterator(It first, It last);
/*! \brief Insert initial scores at the given index */
void InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size);
/*! \brief Insert queries at the given index */
Expand Down
29 changes: 20 additions & 9 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import scipy.sparse

from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat,
dt_DataTable, pa_Array, pa_ChunkedArray, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series)
dt_DataTable, pa_Array, pa_ChunkedArray, pa_compute, pa_Table, pd_CategoricalDtype, pd_DataFrame,
pd_Series)
from .libpath import find_lib_path

if TYPE_CHECKING:
Expand Down Expand Up @@ -115,7 +116,9 @@
List[float],
List[int],
np.ndarray,
pd_Series
pd_Series,
pa_Array,
pa_ChunkedArray,
]
ZERO_THRESHOLD = 1e-35

Expand Down Expand Up @@ -1635,7 +1638,7 @@ def __init__(
Label of the data.
reference : Dataset or None, optional (default=None)
If this is Dataset for validation, training data should be used as reference.
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Weight for each instance. Weights should be non-negative.
group : list, numpy 1-D array, pandas Series or None, optional (default=None)
Group/query data.
Expand Down Expand Up @@ -2415,7 +2418,7 @@ def create_valid(
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Label of the data.
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Weight for each instance. Weights should be non-negative.
group : list, numpy 1-D array, pandas Series or None, optional (default=None)
Group/query data.
Expand Down Expand Up @@ -2830,19 +2833,27 @@ def set_weight(
Parameters
----------
weight : list, numpy 1-D array, pandas Series or None
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
Weight to be set for each data point. Weights should be non-negative.
Returns
-------
self : Dataset
Dataset with set weight.
"""
if weight is not None and np.all(weight == 1):
weight = None
# Check if the weight contains values other than one
if weight is not None:
if _is_pyarrow_array(weight):
if pa_compute.all(pa_compute.equal(weight, 1)).as_py():
weight = None
elif np.all(weight == 1):
weight = None
self.weight = weight

# Set field
if self._handle is not None and weight is not None:
weight = _list_to_1d_numpy(weight, dtype=np.float32, name='weight')
if not _is_pyarrow_array(weight):
weight = _list_to_1d_numpy(weight, dtype=np.float32, name='weight')
self.set_field('weight', weight)
self.weight = self.get_field('weight') # original values can be modified at cpp side
return self
Expand Down Expand Up @@ -4414,7 +4425,7 @@ def refit(
.. versionadded:: 4.0.0
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Weight for each ``data`` instance. Weights should be non-negative.
.. versionadded:: 4.0.0
Expand Down
7 changes: 7 additions & 0 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def __init__(self, *args, **kwargs):

"""pyarrow"""
try:
import pyarrow.compute as pa_compute
from pyarrow import Array as pa_Array
from pyarrow import ChunkedArray as pa_ChunkedArray
from pyarrow import Table as pa_Table
Expand Down Expand Up @@ -236,6 +237,12 @@ class arrow_cffi: # type: ignore
def __init__(self, *args, **kwargs):
pass

class pa_compute: # type: ignore
"""Dummy class for pyarrow.compute."""

all = None
equal = None

arrow_is_integer = None
arrow_is_floating = None

Expand Down
2 changes: 2 additions & 0 deletions src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,8 @@ bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray
name = Common::Trim(name);
if (name == std::string("label") || name == std::string("target")) {
metadata_.SetLabel(ca);
} else if (name == std::string("weight") || name == std::string("weights")) {
metadata_.SetWeights(ca);
} else {
return false;
}
Expand Down
28 changes: 20 additions & 8 deletions src/io/metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,33 +450,45 @@ void Metadata::InsertLabels(const label_t* labels, data_size_t start_index, data
// CUDA is handled after all insertions are complete
}

void Metadata::SetWeights(const label_t* weights, data_size_t len) {
template <typename It>
void Metadata::SetWeightsFromIterator(It first, It last) {
std::lock_guard<std::mutex> lock(mutex_);
// save to nullptr
if (weights == nullptr || len == 0) {
// Clear weights on empty input
if (last - first == 0) {
weights_.clear();
num_weights_ = 0;
return;
}
if (num_data_ != len) {
Log::Fatal("Length of weights is not same with #data");
if (num_data_ != last - first) {
Log::Fatal("Length of weights differs from the length of #data");
}
if (weights_.empty()) {
weights_.resize(num_data_);
}
if (weights_.empty()) { weights_.resize(num_data_); }
num_weights_ = num_data_;

#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (num_weights_ >= 1024)
for (data_size_t i = 0; i < num_weights_; ++i) {
weights_[i] = Common::AvoidInf(weights[i]);
weights_[i] = Common::AvoidInf(first[i]);
}
CalculateQueryWeights();
weight_load_from_file_ = false;

#ifdef USE_CUDA
if (cuda_metadata_ != nullptr) {
cuda_metadata_->SetWeights(weights_.data(), len);
cuda_metadata_->SetWeights(weights_.data(), weights_.size());
}
#endif // USE_CUDA
}

void Metadata::SetWeights(const label_t* weights, data_size_t len) {
SetWeightsFromIterator(weights, weights + len);
}

void Metadata::SetWeights(const ArrowChunkedArray& array) {
SetWeightsFromIterator(array.begin<label_t>(), array.end<label_t>());
}

void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len) {
if (!weights) {
Log::Fatal("Passed null weights");
Expand Down
66 changes: 53 additions & 13 deletions tests/python_package_test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import lightgbm as lgb

from .utils import np_assert_array_equal

# ----------------------------------------------------------------------------------------------- #
# UTILITIES #
# ----------------------------------------------------------------------------------------------- #
Expand Down Expand Up @@ -67,10 +69,6 @@ def dummy_dataset_params() -> Dict[str, Any]:
}


def assert_arrays_equal(lhs: np.ndarray, rhs: np.ndarray):
assert lhs.dtype == rhs.dtype and np.array_equal(lhs, rhs)


# ----------------------------------------------------------------------------------------------- #
# UNIT TESTS #
# ----------------------------------------------------------------------------------------------- #
Expand Down Expand Up @@ -103,6 +101,34 @@ def test_dataset_construct_fuzzy(
assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt")


# -------------------------------------------- FIELDS ------------------------------------------- #


def test_dataset_construct_fields_fuzzy():
arrow_table = generate_random_arrow_table(3, 1000, 42)
arrow_labels = generate_random_arrow_array(1000, 42)
arrow_weights = generate_random_arrow_array(1000, 42)

arrow_dataset = lgb.Dataset(arrow_table, label=arrow_labels, weight=arrow_weights)
arrow_dataset.construct()

pandas_dataset = lgb.Dataset(
arrow_table.to_pandas(), label=arrow_labels.to_numpy(), weight=arrow_weights.to_numpy()
)
pandas_dataset.construct()

# Check for equality
for field in ("label", "weight"):
np_assert_array_equal(
arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True
)
np_assert_array_equal(arrow_dataset.get_label(), pandas_dataset.get_label(), strict=True)
np_assert_array_equal(arrow_dataset.get_weight(), pandas_dataset.get_weight(), strict=True)


# -------------------------------------------- LABELS ------------------------------------------- #


@pytest.mark.parametrize(
["array_type", "label_data"],
[(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
Expand All @@ -129,17 +155,31 @@ def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type:
dataset.construct()

expected = np.array([0, 1, 0, 0, 1], dtype=np.float32)
assert_arrays_equal(expected, dataset.get_label())
np_assert_array_equal(expected, dataset.get_label(), strict=True)


def test_dataset_construct_labels_fuzzy():
arrow_table = generate_random_arrow_table(3, 1000, 42)
arrow_array = generate_random_arrow_array(1000, 42)
# ------------------------------------------- WEIGHTS ------------------------------------------- #

arrow_dataset = lgb.Dataset(arrow_table, label=arrow_array)
arrow_dataset.construct()

pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), label=arrow_array.to_numpy())
pandas_dataset.construct()
def test_dataset_construct_weights_none():
data = generate_dummy_arrow_table()
weight = pa.array([1, 1, 1, 1, 1])
dataset = lgb.Dataset(data, weight=weight, params=dummy_dataset_params())
dataset.construct()
assert dataset.get_weight() is None
assert dataset.get_field("weight") is None


@pytest.mark.parametrize(
["array_type", "weight_data"],
[(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])],
)
@pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()])
def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type: Any):
data = generate_dummy_arrow_table()
weights = array_type(weight_data, type=arrow_type)
dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params())
dataset.construct()

assert_arrays_equal(arrow_dataset.get_label(), pandas_dataset.get_label())
expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32)
np_assert_array_equal(expected, dataset.get_weight(), strict=True)

0 comments on commit deb7077

Please sign in to comment.