Skip to content

Commit

Permalink
Fix calculation of number of bins in FindGroup (#6019)
Browse files Browse the repository at this point in the history
* solve 'bin size 257 cannot run on GPU #3339'

#3339 (comment)

* fix  typo LeafIndex -> leaf_index

---------

Co-authored-by: shiyu1994 <shiyu_k1994@qq.com>
Co-authored-by: James Lamb <jaylamb20@gmail.com>
  • Loading branch information
3 people committed Feb 20, 2024
1 parent 45a60a7 commit d0d7071
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ std::vector<std::vector<int>> FindGroups(
std::vector<int> available_groups;
for (int gid = 0; gid < static_cast<int>(features_in_group.size()); ++gid) {
auto cur_num_bin = group_num_bin[gid] + bin_mappers[fidx]->num_bin() +
(bin_mappers[fidx]->GetDefaultBin() == 0 ? -1 : 0);
(bin_mappers[fidx]->GetMostFreqBin() == 0 ? -1 : 0);
if (group_total_data_cnt[gid] + cur_non_zero_cnt <=
total_sample_cnt + single_val_max_conflict_cnt) {
if (!is_use_gpu || cur_num_bin <= max_bin_per_group) {
Expand Down Expand Up @@ -189,7 +189,7 @@ std::vector<std::vector<int>> FindGroups(
group_used_row_cnt.emplace_back(cur_non_zero_cnt);
group_num_bin.push_back(
1 + bin_mappers[fidx]->num_bin() +
(bin_mappers[fidx]->GetDefaultBin() == 0 ? -1 : 0));
(bin_mappers[fidx]->GetMostFreqBin() == 0 ? -1 : 0));
}
}
if (!is_sparse) {
Expand Down
2 changes: 1 addition & 1 deletion src/treelearner/gpu_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,7 @@ void GPUTreeLearner::FindBestSplits(const Tree* tree) {
size_t bin_size = train_data_->FeatureNumBin(feature_index) + 1;
printf("Feature %d smaller leaf:\n", feature_index);
PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; }
if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->leaf_index() < 0) { continue; }
printf("Feature %d larger leaf:\n", feature_index);
PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
}
Expand Down

0 comments on commit d0d7071

Please sign in to comment.