Skip to content

Commit

Permalink
[R-package] Fix inefficiency in retrieving pointers (#6208)
Browse files (browse the repository at this point in the history)
  • Loading branch information
david-cortes committed Nov 25, 2023
1 parent 516bde9 commit cd36ffe
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions R-package/src/lightgbm_R.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -226,9 +226,10 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
  int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
  std::vector<int32_t> idxvec(len);
  // convert from one-based to zero-based index
+ const int *used_row_indices_ = INTEGER(used_row_indices);
  #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (len >= 1024)
  for (int32_t i = 0; i < len; ++i) {
- idxvec[i] = static_cast<int32_t>(INTEGER(used_row_indices)[i] - 1);
+ idxvec[i] = static_cast<int32_t>(used_row_indices_[i] - 1);
  }
  const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
  DatasetHandle res = nullptr;
Expand Down Expand Up @@ -339,18 +340,20 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
  const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
  if (!strcmp("group", name) || !strcmp("query", name)) {
  std::vector<int32_t> vec(len);
+ const int *field_data_ = INTEGER(field_data);
  #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (len >= 1024)
  for (int i = 0; i < len; ++i) {
- vec[i] = static_cast<int32_t>(INTEGER(field_data)[i]);
+ vec[i] = static_cast<int32_t>(field_data_[i]);
  }
  CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_INT32));
  } else if (!strcmp("init_score", name)) {
  CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, REAL(field_data), len, C_API_DTYPE_FLOAT64));
  } else {
  std::vector<float> vec(len);
+ const double *field_data_ = REAL(field_data);
  #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (len >= 1024)
  for (int i = 0; i < len; ++i) {
- vec[i] = static_cast<float>(REAL(field_data)[i]);
+ vec[i] = static_cast<float>(field_data_[i]);
  }
  CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
  }
Expand All @@ -372,21 +375,24 @@ SEXP LGBM_DatasetGetField_R(SEXP handle,
  if (!strcmp("group", name) || !strcmp("query", name)) {
  auto p_data = reinterpret_cast<const int32_t*>(res);
  // convert from boundaries to size
+ int *field_data_ = INTEGER(field_data);
  #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (out_len >= 1024)
  for (int i = 0; i < out_len - 1; ++i) {
- INTEGER(field_data)[i] = p_data[i + 1] - p_data[i];
+ field_data_[i] = p_data[i + 1] - p_data[i];
  }
  } else if (!strcmp("init_score", name)) {
  auto p_data = reinterpret_cast<const double*>(res);
+ double *field_data_ = REAL(field_data);
  #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (out_len >= 1024)
  for (int i = 0; i < out_len; ++i) {
- REAL(field_data)[i] = p_data[i];
+ field_data_[i] = p_data[i];
  }
  } else {
  auto p_data = reinterpret_cast<const float*>(res);
+ double *field_data_ = REAL(field_data);
  #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (out_len >= 1024)
  for (int i = 0; i < out_len; ++i) {
- REAL(field_data)[i] = p_data[i];
+ field_data_[i] = p_data[i];
  }
  }
  UNPROTECT(1);
Expand Down Expand Up @@ -611,10 +617,12 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
  int is_finished = 0;
  int int_len = Rf_asInteger(len);
  std::vector<float> tgrad(int_len), thess(int_len);
+ const double *grad_ = REAL(grad);
+ const double *hess_ = REAL(hess);
  #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (int_len >= 1024)
  for (int j = 0; j < int_len; ++j) {
- tgrad[j] = static_cast<float>(REAL(grad)[j]);
- thess[j] = static_cast<float>(REAL(hess)[j]);
+ tgrad[j] = static_cast<float>(grad_[j]);
+ thess[j] = static_cast<float>(hess_[j]);
  }
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished));
  return R_NilValue;
Expand Down

0 comments on commit cd36ffe

Please sign in to comment.