[python-package] Allow to pass Arrow table for prediction (#6168)
borchero committed Dec 14, 2023
1 parent 6fc8052 commit 2dfb9a4
Showing 4 changed files with 259 additions and 15 deletions.
34 changes: 34 additions & 0 deletions include/LightGBM/c_api.h
@@ -1417,6 +1417,40 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMats(BoosterHandle handle,
int64_t* out_len,
double* out_result);

/*!
* \brief Make prediction for a new dataset.
* \note
* You should pre-allocate memory for ``out_result``:
* - for normal and raw score, its length is equal to ``num_class * num_data``;
* - for leaf index, its length is equal to ``num_class * num_data * num_iteration``;
* - for feature contributions, its length is equal to ``num_class * num_data * (num_feature + 1)``.
* \param handle Handle of booster
* \param n_chunks The number of Arrow arrays passed to this function
* \param chunks Pointer to the list of Arrow arrays
* \param schema Pointer to the schema of all Arrow arrays
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iteration for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result
* \param[out] out_result Pointer to array with predictions
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForArrow(BoosterHandle handle,
int64_t n_chunks,
const ArrowArray* chunks,
const ArrowSchema* schema,
int predict_type,
int start_iteration,
int num_iteration,
const char* parameter,
int64_t* out_len,
double* out_result);
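
To make the pre-allocation rules above concrete, here is a small illustrative Python helper (hypothetical, not part of this commit) that computes the required length of ``out_result`` for each predict type; the constant values match those defined in c_api.h:

# Hypothetical helper mirroring the ``out_result`` sizing rules documented
# for LGBM_BoosterPredictForArrow; the constants match c_api.h.
C_API_PREDICT_NORMAL = 0
C_API_PREDICT_RAW_SCORE = 1
C_API_PREDICT_LEAF_INDEX = 2
C_API_PREDICT_CONTRIB = 3

def required_out_len(predict_type: int, num_class: int, num_data: int,
                     num_iteration: int, num_feature: int) -> int:
    if predict_type in (C_API_PREDICT_NORMAL, C_API_PREDICT_RAW_SCORE):
        return num_class * num_data
    if predict_type == C_API_PREDICT_LEAF_INDEX:
        return num_class * num_data * num_iteration
    if predict_type == C_API_PREDICT_CONTRIB:
        return num_class * num_data * (num_feature + 1)
    raise ValueError(f"Unknown predict_type: {predict_type}")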

/*!
* \brief Save model into file.
* \param handle Handle of booster
56 changes: 53 additions & 3 deletions python-package/lightgbm/basic.py
@@ -115,7 +115,8 @@
np.ndarray,
pd_DataFrame,
dt_DataTable,
-    scipy.sparse.spmatrix
+    scipy.sparse.spmatrix,
+    pa_Table,
]
_LGBM_WeightType = Union[
List[float],
@@ -1069,7 +1070,7 @@ def predict(
Parameters
----------
-        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
+        data : str, pathlib.Path, numpy array, pandas DataFrame, pyarrow Table, H2O DataTable's Frame or scipy.sparse
Data source for prediction.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
start_iteration : int, optional (default=0)
@@ -1161,6 +1162,13 @@
num_iteration=num_iteration,
predict_type=predict_type
)
elif _is_pyarrow_table(data):
preds, nrow = self.__pred_for_pyarrow_table(
table=data,
start_iteration=start_iteration,
num_iteration=num_iteration,
predict_type=predict_type
)
elif isinstance(data, list):
try:
data = np.array(data)
@@ -1614,6 +1622,48 @@ def __pred_for_csc(
if n_preds != out_num_preds.value:
raise ValueError("Wrong length for predict results")
return preds, nrow

def __pred_for_pyarrow_table(
self,
table: pa_Table,
start_iteration: int,
num_iteration: int,
predict_type: int
) -> Tuple[np.ndarray, int]:
"""Predict for a PyArrow table."""
if not PYARROW_INSTALLED:
raise LightGBMError("Cannot predict from Arrow without `pyarrow` installed.")

# Check that the input is valid: we only handle numbers (for now)
if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types):
raise ValueError("Arrow table may only have integer or floating point datatypes")

# Prepare prediction output array
n_preds = self.__get_num_preds(
start_iteration=start_iteration,
num_iteration=num_iteration,
nrow=table.num_rows,
predict_type=predict_type
)
preds = np.empty(n_preds, dtype=np.float64)
out_num_preds = ctypes.c_int64(0)

# Export Arrow table to C and run prediction
c_array = _export_arrow_to_c(table)
_safe_call(_LIB.LGBM_BoosterPredictForArrow(
self._handle,
ctypes.c_int64(c_array.n_chunks),
ctypes.c_void_p(c_array.chunks_ptr),
ctypes.c_void_p(c_array.schema_ptr),
ctypes.c_int(predict_type),
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
_c_str(self.pred_parameter),
ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value:
raise ValueError("Wrong length for predict results")
return preds, table.num_rows
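
With this method in place, ``Booster.predict`` accepts a ``pyarrow.Table`` directly. A minimal usage sketch (illustrative data and parameters, assuming a LightGBM build with Arrow support and ``pyarrow`` installed):

import numpy as np
import pyarrow as pa
import lightgbm as lgb

# Arrow prediction only accepts integer or floating point columns,
# so build a small all-numeric table.
rng = np.random.default_rng(0)
table = pa.Table.from_arrays(
    [pa.array(rng.standard_normal(100)) for _ in range(3)],
    names=["f0", "f1", "f2"],
)

train_set = lgb.Dataset(table, label=rng.standard_normal(100))
booster = lgb.train({"objective": "regression", "num_leaves": 7}, train_set, num_boost_round=5)

# The table is exported via the Arrow C data interface and handed to
# LGBM_BoosterPredictForArrow; the result is a NumPy array of length num_rows.
preds = booster.predict(table)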

def current_iteration(self) -> int:
"""Get the index of the current iteration.
@@ -4350,7 +4400,7 @@ def predict(
Parameters
----------
-        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
+        data : str, pathlib.Path, numpy array, pandas DataFrame, pyarrow Table, H2O DataTable's Frame or scipy.sparse
Data source for prediction.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
start_iteration : int, optional (default=0)
51 changes: 51 additions & 0 deletions src/c_api.cpp
@@ -2568,6 +2568,57 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle,
API_END();
}

int LGBM_BoosterPredictForArrow(BoosterHandle handle,
int64_t n_chunks,
const ArrowArray* chunks,
const ArrowSchema* schema,
int predict_type,
int start_iteration,
int num_iteration,
const char* parameter,
int64_t* out_len,
double* out_result) {
API_BEGIN();

// Apply the configuration
auto param = Config::Str2Map(parameter);
Config config;
config.Set(param);
OMP_SET_NUM_THREADS(config.num_threads);

// Set up chunked array and iterators for all columns
ArrowTable table(n_chunks, chunks, schema);
std::vector<ArrowChunkedArray::Iterator<double>> its;
its.reserve(table.get_num_columns());
for (int64_t j = 0; j < table.get_num_columns(); ++j) {
its.emplace_back(table.get_column(j).begin<double>());
}

// Build row function
auto num_columns = table.get_num_columns();
auto row_fn = [num_columns, &its] (int row_idx) {
std::vector<std::pair<int, double>> result;
result.reserve(num_columns);
for (int64_t j = 0; j < num_columns; ++j) {
result.emplace_back(static_cast<int>(j), its[j][row_idx]);
}
return result;
};

// Run prediction
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->Predict(start_iteration,
num_iteration,
predict_type,
static_cast<int>(table.get_num_rows()),
static_cast<int>(table.get_num_columns()),
row_fn,
config,
out_result,
out_len);
API_END();
}
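
The ``row_fn`` lambda above materializes one row at a time as ``(column_index, value)`` pairs, which is the sparse-style row format the booster's predict path consumes. An illustrative Python analogue (hypothetical, for exposition only):

from typing import List, Sequence, Tuple

def make_row_fn(columns: Sequence[Sequence[float]]):
    # Mirrors the C++ lambda built over the Arrow column iterators:
    # given a row index, emit one (column_index, value) pair per column.
    def row_fn(row_idx: int) -> List[Tuple[int, float]]:
        return [(j, col[row_idx]) for j, col in enumerate(columns)]
    return row_fn

# Two columns, three rows:
row_fn = make_row_fn([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])
assert row_fn(1) == [(0, 2.0), (1, 20.0)]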

int LGBM_BoosterSaveModel(BoosterHandle handle,
int start_iteration,
int num_iteration,
133 changes: 121 additions & 12 deletions tests/python_package_test/test_arrow.py
@@ -1,6 +1,6 @@
# coding: utf-8
import filecmp
-from typing import Any, Dict
+from typing import Any, Dict, Optional

import numpy as np
import pyarrow as pa
@@ -63,19 +63,40 @@ def generate_dummy_arrow_table() -> pa.Table:
return pa.Table.from_arrays([col1, col2], names=["a", "b"])


-def generate_random_arrow_table(num_columns: int, num_datapoints: int, seed: int) -> pa.Table:
-    columns = [generate_random_arrow_array(num_datapoints, seed + i) for i in range(num_columns)]
+def generate_random_arrow_table(
+    num_columns: int,
+    num_datapoints: int,
+    seed: int,
+    generate_nulls: bool = True,
+    values: Optional[np.ndarray] = None,
+) -> pa.Table:
+    columns = [
+        generate_random_arrow_array(
+            num_datapoints, seed + i, generate_nulls=generate_nulls, values=values
+        )
+        for i in range(num_columns)
+    ]
names = [f"col_{i}" for i in range(num_columns)]
return pa.Table.from_arrays(columns, names=names)


-def generate_random_arrow_array(num_datapoints: int, seed: int) -> pa.ChunkedArray:
+def generate_random_arrow_array(
+    num_datapoints: int,
+    seed: int,
+    generate_nulls: bool = True,
+    values: Optional[np.ndarray] = None,
+) -> pa.ChunkedArray:
generator = np.random.default_rng(seed)
data = generator.standard_normal(num_datapoints)
data = (
generator.standard_normal(num_datapoints)
if values is None
else generator.choice(values, size=num_datapoints, replace=True)
)

# Set random nulls
-    indices = generator.choice(len(data), size=num_datapoints // 10)
-    data[indices] = None
+    if generate_nulls:
+        indices = generator.choice(len(data), size=num_datapoints // 10)
+        data[indices] = None

# Split data into <=2 random chunks
split_points = np.sort(generator.choice(np.arange(1, num_datapoints), 2, replace=False))
@@ -131,8 +152,8 @@ def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params):

def test_dataset_construct_fields_fuzzy():
arrow_table = generate_random_arrow_table(3, 1000, 42)
-    arrow_labels = generate_random_arrow_array(1000, 42)
-    arrow_weights = generate_random_arrow_array(1000, 42)
+    arrow_labels = generate_random_arrow_array(1000, 42, generate_nulls=False)
+    arrow_weights = generate_random_arrow_array(1000, 42, generate_nulls=False)
arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.int32())

arrow_dataset = lgb.Dataset(
@@ -264,9 +285,9 @@ def test_dataset_construct_init_scores_table():
data = generate_dummy_arrow_table()
init_scores = pa.Table.from_arrays(
[
-            generate_random_arrow_array(5, seed=1),
-            generate_random_arrow_array(5, seed=2),
-            generate_random_arrow_array(5, seed=3),
+            generate_random_arrow_array(5, seed=1, generate_nulls=False),
+            generate_random_arrow_array(5, seed=2, generate_nulls=False),
+            generate_random_arrow_array(5, seed=3, generate_nulls=False),
],
names=["a", "b", "c"],
)
@@ -276,3 +297,91 @@ def test_dataset_construct_init_scores_table():
actual = dataset.get_init_score()
expected = init_scores.to_pandas().to_numpy().astype(np.float64)
np_assert_array_equal(expected, actual, strict=True)


# ------------------------------------------ PREDICTION ----------------------------------------- #


def assert_equal_predict_arrow_pandas(booster: lgb.Booster, data: pa.Table):
p_arrow = booster.predict(data)
p_pandas = booster.predict(data.to_pandas())
np_assert_array_equal(p_arrow, p_pandas, strict=True)

p_raw_arrow = booster.predict(data, raw_score=True)
p_raw_pandas = booster.predict(data.to_pandas(), raw_score=True)
np_assert_array_equal(p_raw_arrow, p_raw_pandas, strict=True)

p_leaf_arrow = booster.predict(data, pred_leaf=True)
p_leaf_pandas = booster.predict(data.to_pandas(), pred_leaf=True)
np_assert_array_equal(p_leaf_arrow, p_leaf_pandas, strict=True)

p_pred_contrib_arrow = booster.predict(data, pred_contrib=True)
p_pred_contrib_pandas = booster.predict(data.to_pandas(), pred_contrib=True)
np_assert_array_equal(p_pred_contrib_arrow, p_pred_contrib_pandas, strict=True)

p_first_iter_arrow = booster.predict(data, start_iteration=0, num_iteration=1, raw_score=True)
p_first_iter_pandas = booster.predict(
data.to_pandas(), start_iteration=0, num_iteration=1, raw_score=True
)
np_assert_array_equal(p_first_iter_arrow, p_first_iter_pandas, strict=True)


def test_predict_regression():
data = generate_random_arrow_table(10, 10000, 42)
dataset = lgb.Dataset(
data,
label=generate_random_arrow_array(10000, 43, generate_nulls=False),
params=dummy_dataset_params(),
)
booster = lgb.train(
{"objective": "regression", "num_leaves": 7},
dataset,
num_boost_round=5,
)
assert_equal_predict_arrow_pandas(booster, data)


def test_predict_binary_classification():
data = generate_random_arrow_table(10, 10000, 42)
dataset = lgb.Dataset(
data,
label=generate_random_arrow_array(10000, 43, generate_nulls=False, values=np.arange(2)),
params=dummy_dataset_params(),
)
booster = lgb.train(
{"objective": "binary", "num_leaves": 7},
dataset,
num_boost_round=5,
)
assert_equal_predict_arrow_pandas(booster, data)


def test_predict_multiclass_classification():
data = generate_random_arrow_table(10, 10000, 42)
dataset = lgb.Dataset(
data,
label=generate_random_arrow_array(10000, 43, generate_nulls=False, values=np.arange(5)),
params=dummy_dataset_params(),
)
booster = lgb.train(
{"objective": "multiclass", "num_leaves": 7, "num_class": 5},
dataset,
num_boost_round=5,
)
assert_equal_predict_arrow_pandas(booster, data)


def test_predict_ranking():
data = generate_random_arrow_table(10, 10000, 42)
dataset = lgb.Dataset(
data,
label=generate_random_arrow_array(10000, 43, generate_nulls=False, values=np.arange(4)),
group=np.array([1000, 2000, 3000, 4000]),
params=dummy_dataset_params(),
)
booster = lgb.train(
{"objective": "lambdarank", "num_leaves": 7},
dataset,
num_boost_round=5,
)
assert_equal_predict_arrow_pandas(booster, data)
