
set explicit number of threads in every OpenMP parallel region (#6135)
jameslamb committed Oct 10, 2023
1 parent 992f505 commit 8ed371c
Showing 45 changed files with 226 additions and 201 deletions.
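
Every change below follows one pattern: each `#pragma omp parallel` region gains an explicit `num_threads(...)` clause, almost always `num_threads(OMP_NUM_THREADS())`. `OMP_NUM_THREADS()` is LightGBM's own helper from `openmp_wrapper.h` (the one file the new lint check exempts). As rough orientation only, a helper like that can be sketched as follows; this is a hypothetical illustration, not the project's actual definition:

#include <omp.h>

// Hypothetical library-level cap on threads; <= 0 would mean "no cap set".
static int LGBM_MAX_NUM_THREADS = -1;

// Sketch only: the number of threads OpenMP would use by default,
// clamped to the library-level cap when one is set.
inline int OMP_NUM_THREADS() {
  int n = omp_get_max_threads();  // honors the OMP_NUM_THREADS env var / omp_set_num_threads()
  if (LGBM_MAX_NUM_THREADS > 0 && LGBM_MAX_NUM_THREADS < n) {
    n = LGBM_MAX_NUM_THREADS;
  }
  return n;
}

Pinning the team size on every region makes thread usage predictable even when callers change global OpenMP state between library calls.
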
25 changes: 25 additions & 0 deletions .ci/lint-cpp.sh
@@ -18,3 +18,28 @@ cmakelint \
${cmake_files} \
|| exit -1
echo "done running cmakelint"

echo "checking that all OpenMP pragmas specify num_threads()"
get_omp_pragmas_without_num_threads() {
grep \
-n \
-R \
--include='*.c' \
--include='*.cc' \
--include='*.cpp' \
--include='*.h' \
--include='*.hpp' \
'pragma omp parallel' \
| grep -v ' num_threads' \
| grep -v 'openmp_wrapper.h'
}
PROBLEMATIC_LINES=$(
get_omp_pragmas_without_num_threads
)
if test "${PROBLEMATIC_LINES}" != ""; then
get_omp_pragmas_without_num_threads
echo "Found '#pragma omp parallel' not using explicit num_threads() configuration. Fix those."
echo "For details, see https://www.openmp.org/spec-html/5.0/openmpse14.html#x54-800002.6"
exit -1
fi
echo "done checking OpenMP pragmas"
14 changes: 7 additions & 7 deletions R-package/src/lightgbm_R.cpp
@@ -226,7 +226,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
std::vector<int32_t> idxvec(len);
// convert from one-based to zero-based index
- #pragma omp parallel for schedule(static, 512) if (len >= 1024)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (len >= 1024)
for (int32_t i = 0; i < len; ++i) {
idxvec[i] = static_cast<int32_t>(INTEGER(used_row_indices)[i] - 1);
}
@@ -339,7 +339,7 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
if (!strcmp("group", name) || !strcmp("query", name)) {
std::vector<int32_t> vec(len);
- #pragma omp parallel for schedule(static, 512) if (len >= 1024)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (len >= 1024)
for (int i = 0; i < len; ++i) {
vec[i] = static_cast<int32_t>(INTEGER(field_data)[i]);
}
@@ -348,7 +348,7 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, REAL(field_data), len, C_API_DTYPE_FLOAT64));
} else {
std::vector<float> vec(len);
- #pragma omp parallel for schedule(static, 512) if (len >= 1024)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (len >= 1024)
for (int i = 0; i < len; ++i) {
vec[i] = static_cast<float>(REAL(field_data)[i]);
}
@@ -372,19 +372,19 @@ SEXP LGBM_DatasetGetField_R(SEXP handle,
if (!strcmp("group", name) || !strcmp("query", name)) {
auto p_data = reinterpret_cast<const int32_t*>(res);
// convert from boundaries to size
- #pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (out_len >= 1024)
for (int i = 0; i < out_len - 1; ++i) {
INTEGER(field_data)[i] = p_data[i + 1] - p_data[i];
}
} else if (!strcmp("init_score", name)) {
auto p_data = reinterpret_cast<const double*>(res);
- #pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (out_len >= 1024)
for (int i = 0; i < out_len; ++i) {
REAL(field_data)[i] = p_data[i];
}
} else {
auto p_data = reinterpret_cast<const float*>(res);
- #pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (out_len >= 1024)
for (int i = 0; i < out_len; ++i) {
REAL(field_data)[i] = p_data[i];
}
@@ -611,7 +611,7 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
int is_finished = 0;
int int_len = Rf_asInteger(len);
std::vector<float> tgrad(int_len), thess(int_len);
- #pragma omp parallel for schedule(static, 512) if (int_len >= 1024)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (int_len >= 1024)
for (int j = 0; j < int_len; ++j) {
tgrad[j] = static_cast<float>(REAL(grad)[j]);
thess[j] = static_cast<float>(REAL(hess)[j]);
2 changes: 1 addition & 1 deletion include/LightGBM/feature_group.h
@@ -361,7 +361,7 @@ class FeatureGroup {
inline void FinishLoad() {
if (is_multi_val_) {
OMP_INIT_EX();
- #pragma omp parallel for schedule(guided)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(guided)
for (int i = 0; i < num_feature_; ++i) {
OMP_LOOP_EX_BEGIN();
multi_bin_data_[i]->FinishLoad();
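
The `OMP_INIT_EX()` / `OMP_LOOP_EX_BEGIN()` macros in this hunk exist because C++ exceptions must not propagate out of an OpenMP parallel region. The pattern: wrap each iteration in a try block, capture the first exception thrown, and rethrow it after the region ends. A self-contained sketch of that pattern, assuming nothing about the actual macro bodies in openmp_wrapper.h:

#include <exception>
#include <mutex>

class ThreadExceptionHelper {
 public:
  void CaptureException() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (ex_ == nullptr) ex_ = std::current_exception();  // keep only the first
  }
  void ReThrow() {
    if (ex_ != nullptr) std::rethrow_exception(ex_);
  }
 private:
  std::exception_ptr ex_;
  std::mutex mutex_;
};

#define OMP_INIT_EX() ThreadExceptionHelper omp_except_helper
#define OMP_LOOP_EX_BEGIN() try {
#define OMP_LOOP_EX_END() } catch (...) { omp_except_helper.CaptureException(); }
#define OMP_THROW_EX() omp_except_helper.ReThrow()
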
6 changes: 3 additions & 3 deletions include/LightGBM/tree.h
@@ -185,7 +185,7 @@ class Tree {
* \param rate The factor of shrinkage
*/
virtual inline void Shrinkage(double rate) {
- #pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 1024) if (num_leaves_ >= 2048)
for (int i = 0; i < num_leaves_ - 1; ++i) {
leaf_value_[i] = MaybeRoundToZero(leaf_value_[i] * rate);
internal_value_[i] = MaybeRoundToZero(internal_value_[i] * rate);
@@ -210,15 +210,15 @@
inline double shrinkage() const { return shrinkage_; }

virtual inline void AddBias(double val) {
- #pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 1024) if (num_leaves_ >= 2048)
for (int i = 0; i < num_leaves_ - 1; ++i) {
leaf_value_[i] = MaybeRoundToZero(leaf_value_[i] + val);
internal_value_[i] = MaybeRoundToZero(internal_value_[i] + val);
}
leaf_value_[num_leaves_ - 1] =
MaybeRoundToZero(leaf_value_[num_leaves_ - 1] + val);
if (is_linear_) {
- #pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 1024) if (num_leaves_ >= 2048)
for (int i = 0; i < num_leaves_ - 1; ++i) {
leaf_const_[i] = MaybeRoundToZero(leaf_const_[i] + val);
}
4 changes: 2 additions & 2 deletions include/LightGBM/utils/common.h
@@ -691,7 +691,7 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
size_t inner_size = (len + num_threads - 1) / num_threads;
inner_size = std::max(inner_size, kMinInnerLen);
num_threads = static_cast<int>((len + inner_size - 1) / inner_size);
- #pragma omp parallel for schedule(static, 1)
+ #pragma omp parallel for num_threads(num_threads) schedule(static, 1)
for (int i = 0; i < num_threads; ++i) {
size_t left = inner_size*i;
size_t right = left + inner_size;
@@ -707,7 +707,7 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
// Recursive merge
while (s < len) {
int loop_size = static_cast<int>((len + s * 2 - 1) / (s * 2));
- #pragma omp parallel for schedule(static, 1)
+ #pragma omp parallel for num_threads(num_threads) schedule(static, 1)
for (int i = 0; i < loop_size; ++i) {
size_t left = i * 2 * s;
size_t mid = left + s;
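
`ParallelSort` is the notable exception in this commit: it passes the locally computed `num_threads` instead of `OMP_NUM_THREADS()`, because the loop runs exactly one iteration per block and the team size must match the block decomposition computed just above the pragma. That decomposition is plain ceiling division with a minimum block length; restated standalone:

#include <algorithm>
#include <cstddef>

// Restates the block-sizing logic from the hunk above, with simplified names:
// split `len` items into at most `max_threads` blocks of at least
// `min_inner_len` items each, returning the number of blocks actually used.
static int NumSortBlocks(size_t len, int max_threads, size_t min_inner_len) {
  size_t inner_size = (len + max_threads - 1) / max_threads;     // ceil(len / max_threads)
  inner_size = std::max(inner_size, min_inner_len);
  return static_cast<int>((len + inner_size - 1) / inner_size);  // ceil(len / inner_size)
}

With `schedule(static, 1)` each thread sorts exactly one block, and the second loop then merges pairs of sorted blocks until a single sorted range remains.
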
2 changes: 1 addition & 1 deletion include/LightGBM/utils/threading.h
@@ -73,7 +73,7 @@ class Threading {
INDEX_T num_inner = end - start;
BlockInfo<INDEX_T>(num_inner, min_block_size, &n_block, &num_inner);
OMP_INIT_EX();
- #pragma omp parallel for schedule(static, 1)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 1)
for (int i = 0; i < n_block; ++i) {
OMP_LOOP_EX_BEGIN();
INDEX_T inner_start = start + num_inner * i;
2 changes: 1 addition & 1 deletion src/application/application.cpp
@@ -227,7 +227,7 @@ void Application::Predict() {
TextReader<int> result_reader(config_.output_result.c_str(), false);
result_reader.ReadAllLines();
std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size());
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < static_cast<int>(result_reader.Lines().size()); ++i) {
pred_leaf[i] = Common::StringToArray<int>(result_reader.Lines()[i], '\t');
// Free memory
2 changes: 1 addition & 1 deletion src/application/predictor.hpp
@@ -233,7 +233,7 @@ class Predictor {
std::vector<std::pair<int, double>> oneline_features;
std::vector<std::string> result_to_write(lines.size());
OMP_INIT_EX();
- #pragma omp parallel for schedule(static) firstprivate(oneline_features)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) firstprivate(oneline_features)
for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
OMP_LOOP_EX_BEGIN();
oneline_features.clear();
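
The predictor's pragma keeps its `firstprivate(oneline_features)` clause: every thread starts from its own copy of the scratch buffer, so the clear-and-refill pattern inside the loop never races. A standalone illustration of the clause (not LightGBM code):

#include <vector>

void FillRows(int n) {
  std::vector<int> buf;  // declared once; copied into each thread below
  #pragma omp parallel for num_threads(4) schedule(static) firstprivate(buf)
  for (int i = 0; i < n; ++i) {
    buf.clear();         // safe: operates on this thread's own copy
    buf.push_back(i);    // no data race with other threads
  }
}

(`private`, used later in this commit for `thread_buffer` in c_api.cpp, default-constructs the per-thread copy instead of copy-constructing it.)
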
10 changes: 5 additions & 5 deletions src/boosting/gbdt.cpp
@@ -255,7 +255,7 @@ void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction)
std::vector<int> leaf_pred(num_data_);
if (linear_tree_) {
std::vector<int> max_leaves_by_thread = std::vector<int>(OMP_NUM_THREADS(), 0);
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < static_cast<int>(tree_leaf_prediction.size()); ++i) {
int tid = omp_get_thread_num();
for (size_t j = 0; j < tree_leaf_prediction[i].size(); ++j) {
@@ -270,7 +270,7 @@ void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction)
Boosting();
for (int tree_id = 0; tree_id < num_tree_per_iteration_; ++tree_id) {
int model_index = iter * num_tree_per_iteration_ + tree_id;
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < num_data_; ++i) {
leaf_pred[i] = tree_leaf_prediction[i][model_index];
CHECK_LT(leaf_pred[i], models_[model_index]->num_leaves());
@@ -348,7 +348,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
if (data_sample_strategy_->IsHessianChange()) {
// need to copy customized gradients when using GOSS
int64_t total_size = static_cast<int64_t>(num_data_) * num_tree_per_iteration_;
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int64_t i = 0; i < total_size; ++i) {
gradients_[i] = gradients[i];
hessians_[i] = hessians[i];
@@ -669,7 +669,7 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
}
#endif // USE_CUDA
if (objective_function_ != nullptr) {
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (data_size_t i = 0; i < num_data; ++i) {
std::vector<double> tree_pred(num_tree_per_iteration_);
for (int j = 0; j < num_tree_per_iteration_; ++j) {
@@ -682,7 +682,7 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
}
}
} else {
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (data_size_t i = 0; i < num_data; ++i) {
for (int j = 0; j < num_tree_per_iteration_; ++j) {
out_result[j * num_data + i] = static_cast<double>(raw_scores[j * num_data + i]);
2 changes: 1 addition & 1 deletion src/boosting/gbdt.h
@@ -434,7 +434,7 @@ class GBDT : public GBDTBase {
}
start_iteration_for_pred_ = start_iteration;
if (is_pred_contrib) {
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < static_cast<int>(models_.size()); ++i) {
models_[i]->RecomputeMaxDepth();
}
4 changes: 2 additions & 2 deletions src/boosting/gbdt_model_text.cpp
@@ -354,7 +354,7 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration, int
std::vector<std::string> tree_strs(num_used_model - start_model);
std::vector<size_t> tree_sizes(num_used_model - start_model);
// output tree models
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = start_model; i < num_used_model; ++i) {
const int idx = i - start_model;
tree_strs[idx] = "Tree=" + std::to_string(idx) + '\n';
@@ -552,7 +552,7 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
models_.emplace_back(nullptr);
}
OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < num_trees; ++i) {
OMP_LOOP_EX_BEGIN();
auto cur_p = p + tree_boundries[i];
2 changes: 1 addition & 1 deletion src/boosting/rf.hpp
@@ -97,7 +97,7 @@ class RF : public GBDT {
}
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
std::vector<double> tmp_scores(total_size, 0.0f);
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int j = 0; j < num_tree_per_iteration_; ++j) {
size_t offset = static_cast<size_t>(j)* num_data_;
for (data_size_t i = 0; i < num_data_; ++i) {
6 changes: 3 additions & 3 deletions src/boosting/score_updater.hpp
@@ -39,7 +39,7 @@ class ScoreUpdater {
Log::Fatal("Number of class for initial score error");
}
has_init_score_ = true;
- #pragma omp parallel for schedule(static, 512) if (total_size >= 1024)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (total_size >= 1024)
for (int64_t i = 0; i < total_size; ++i) {
score_[i] = init_score[i];
}
@@ -54,15 +54,15 @@
virtual inline void AddScore(double val, int cur_tree_id) {
Common::FunctionTimer fun_timer("ScoreUpdater::AddScore", global_timer);
const size_t offset = static_cast<size_t>(num_data_) * cur_tree_id;
- #pragma omp parallel for schedule(static, 512) if (num_data_ >= 1024)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (num_data_ >= 1024)
for (int i = 0; i < num_data_; ++i) {
score_[offset + i] += val;
}
}

virtual inline void MultiplyScore(double val, int cur_tree_id) {
const size_t offset = static_cast<size_t>(num_data_) * cur_tree_id;
- #pragma omp parallel for schedule(static, 512) if (num_data_ >= 1024)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (num_data_ >= 1024)
for (int i = 0; i < num_data_; ++i) {
score_[offset + i] *= val;
}
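
Alongside `num_threads()`, the pragmas in this file keep their `if (... >= 1024)` clauses: when the condition is false at runtime, OpenMP executes the region with a team of one thread, so short loops skip the fork/join overhead entirely. A standalone demonstration of the `if` clause:

#include <omp.h>
#include <cstdio>

int main() {
  int sizes[] = {100, 10000};
  for (int n : sizes) {
    #pragma omp parallel num_threads(4) if (n >= 1024)
    {
      #pragma omp single
      printf("n=%5d -> team of %d thread(s)\n", n, omp_get_num_threads());
    }
  }
  return 0;
}

Typical output (with at least 4 cores available): `n=  100 -> team of 1 thread(s)` followed by `n=10000 -> team of 4 thread(s)`.
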
26 changes: 13 additions & 13 deletions src/c_api.cpp
@@ -437,7 +437,7 @@ class Booster {
int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(start_iteration, num_iteration, is_predict_leaf, predict_contrib);
auto pred_fun = predictor.GetPredictFunction();
OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
auto one_row = get_row_fun(i);
@@ -459,7 +459,7 @@ class Booster {
auto pred_sparse_fun = predictor.GetPredictSparseFunction();
std::vector<std::vector<std::unordered_map<int, double>>>& agg = *agg_ptr;
OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int64_t i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
auto one_row = get_row_fun(i);
@@ -551,7 +551,7 @@ class Booster {
indptr_index++;
int64_t matrix_start_index = m * static_cast<int64_t>(agg.size());
OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int64_t i = 0; i < static_cast<int64_t>(agg.size()); ++i) {
OMP_LOOP_EX_BEGIN();
auto row_vector = agg[i];
@@ -663,7 +663,7 @@ class Booster {
}
// Note: we parallelize across matrices instead of rows because of the column_counts[m][col_idx] increment inside the loop
OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int m = 0; m < num_matrices; ++m) {
OMP_LOOP_EX_BEGIN();
for (int64_t i = 0; i < static_cast<int64_t>(agg.size()); ++i) {
@@ -1074,7 +1074,7 @@ int LGBM_DatasetPushRows(DatasetHandle dataset,
p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow);
}
OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num();
@@ -1116,7 +1116,7 @@ int LGBM_DatasetPushRowsWithMetadata(DatasetHandle dataset,
const int max_omp_threads = p_dataset->omp_max_threads() > 0 ? p_dataset->omp_max_threads() : OMP_NUM_THREADS();

OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
// convert internal thread id to be unique based on external thread id
@@ -1153,7 +1153,7 @@ int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow);
}
OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num();
@@ -1199,7 +1199,7 @@ int LGBM_DatasetPushRowsByCSRWithMetadata(DatasetHandle dataset,
const int max_omp_threads = p_dataset->omp_max_threads() > 0 ? p_dataset->omp_max_threads() : OMP_NUM_THREADS();

OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
// convert internal thread id to be unique based on external thread id
@@ -1319,7 +1319,7 @@ int LGBM_DatasetCreateFromMats(int32_t nmat,
int32_t start_row = 0;
for (int j = 0; j < nmat; ++j) {
OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < nrow[j]; ++i) {
OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num();
@@ -1394,7 +1394,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
}
}
OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < nindptr - 1; ++i) {
OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num();
@@ -1465,7 +1465,7 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,

OMP_INIT_EX();
std::vector<std::pair<int, double>> thread_buffer;
- #pragma omp parallel for schedule(static) private(thread_buffer)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) private(thread_buffer)
for (int i = 0; i < num_rows; ++i) {
OMP_LOOP_EX_BEGIN();
{
@@ -1506,7 +1506,7 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
OMP_LOOP_EX_BEGIN();
CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
@@ -1534,7 +1534,7 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
reinterpret_cast<const Dataset*>(reference));
}
OMP_INIT_EX();
- #pragma omp parallel for schedule(static)
+ #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < ncol_ptr - 1; ++i) {
OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num();
