[REVIEW] ENH Decision Tree new backend computeSplit*Kernel histogram calculation optimization #3674

venkywonka
Contributor

@venkywonka venkywonka commented Mar 31, 2021

This is a follow-up to PR #3616 and should be merged after it.
This PR introduces:

  • Modularizing the pdf_to_cdf conversion (an inclusive sum-scan) into a device function so that it can be reused by both ML::DecisionTree::computeSplitClassificationKernel and ML::DecisionTree::computeSplitRegressionKernel.
  • Integrating that device function into ML::DecisionTree::computeSplitRegressionKernel to calculate the prediction sums and counts. These histograms are used for node splitting in decision trees for regression.
  • The rationale for this optimization is the same as the one given in PR #3616 (ENH Decision Tree new backend computeSplitClassificationKernel histogram calculation and occupancy optimization).
  • As of now, only the first pass has been optimized using sum-scans.
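
As a rough illustration of what a pdf-to-cdf conversion does, the following is a minimal host-side C++ sketch, not the device function from this PR. The name `pdf_to_cdf_host` is hypothetical; it only shows that an inclusive prefix sum over per-bin values (the "pdf") yields cumulative values (the "cdf").

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical host-side helper (NOT the cuML device function): converts
// per-bin values (the "pdf") into cumulative values (the "cdf") using an
// inclusive prefix sum, so that cdf[i] = pdf[0] + ... + pdf[i].
template <typename DataT>
std::vector<DataT> pdf_to_cdf_host(const std::vector<DataT>& pdf) {
  std::vector<DataT> cdf(pdf.size());
  std::partial_sum(pdf.begin(), pdf.end(), cdf.begin());
  return cdf;
}
```

Because the helper is templated on the element type, the same routine could serve integer counts and floating-point prediction sums, which mirrors the reuse across the classification and regression kernels described above.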

  * using atomics to calculate PDFs and then a blockScan to get the
  required CDFs, instead of the original approach, which issued too many
  atomicAdds to shared memory
    * dynamically assigning based on occupancy while ceil-ing it to minimum 4 blocks
    * pruning unnecessary comments and code
    * improving doxygen comments
    * adding some explanatory comments
…stogram-calculation-optimization-for-computesplitclassificationkernel
    * shift the blockscan code to a reusable device function
    * change appropriately in `computeSplitClassificationKernel`
…ogram-calculation-optimization-computeSplitRegressionKernel
@venkywonka venkywonka added the improvement (Improvement / enhancement to an existing function), non-breaking (Non-breaking change), and Perf (Related to runtime performance of the underlying code) labels Mar 31, 2021
typedef cub::BlockScan<DataT, TPB> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;

for (IdxT tix = threadIdx.x; tix < max(TPB, nbins); tix += blockDim.x) {
Member

If nbins > TPB, then the resulting scan will be incorrect (because the total sum of the previous iteration is not being carried forward for the next iteration).

Member

IOW, InclusiveSum provides an option to also get the total sum with this function: https://nvlabs.github.io/cub/classcub_1_1_block_scan.html#a99222ab9b122e6df879ee04b4e8244da
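
The issue raised in this thread can be sketched on the host. The following C++ emulation is an illustration under assumptions, not the kernel code: `tiled_inclusive_scan` is a hypothetical name, each tile of `tpb` elements stands in for one block-wide scan, and `carry` plays the role of the block aggregate that the three-argument `InclusiveSum` overload returns. Setting `carry_forward` to `false` reproduces the incorrect result for `nbins > TPB`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side emulation (an assumption, not the kernel code): scan `pdf` in
// tiles of `tpb` elements, the way a block of `tpb` threads would iterate.
// Each tile is scanned on its own and then offset by `carry`, the total of
// all previous tiles.
template <typename DataT>
std::vector<DataT> tiled_inclusive_scan(const std::vector<DataT>& pdf,
                                        std::size_t tpb,
                                        bool carry_forward = true) {
  assert(tpb > 0);
  std::vector<DataT> cdf(pdf.size());
  DataT carry = DataT(0);
  for (std::size_t start = 0; start < pdf.size(); start += tpb) {
    const std::size_t end = std::min(start + tpb, pdf.size());
    DataT running = DataT(0);  // tile-local inclusive sum
    for (std::size_t i = start; i < end; ++i) {
      running += pdf[i];
      cdf[i] = running + carry;
    }
    // `running` is now the tile total. Skipping this update restarts every
    // tile from zero, which is the bug described in the review comment.
    if (carry_forward) carry += running;
  }
  return cdf;
}
```

With a single tile (`tpb >= pdf.size()`) both variants agree, which is why the problem only shows up once the number of bins exceeds the block size.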

Contributor Author

😅 Thank you for that, have rectified it 👍🏻

// locations
offset_cdf += nbins;
//convert pdf to cdf
pdf_to_cdf<int, IdxT, TPB>(pdf_shist + offset_pdf, cdf_shist + offset_cdf,
Member

why not just compute this cdf using the cdf from the above and its total sum? That way, we could avoid an extra block-scan operation.

Contributor Author

I did think of that initially, but leveraging total-sum from previous sumscan did not strike me! I have added it thejaswi 🙏🏻

…ogram-calculation-optimization-computeSplitRegressionKernel
   * incorporating block_aggregate in inclusive sumscan
   * using total sum instead of doing a right-to-left sumscan
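
The idea of replacing a right-to-left scan with the total sum can be sketched as follows. This is a host-side C++ illustration under assumptions: `right_sums_from_left_cdf` is a hypothetical name, and the convention shown (sum of the bins strictly to the right of bin `i`) is one possible choice; the convention used by the kernel is not stated in this conversation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the PR). Given the inclusive left-to-right
// scan
//   left[i]  = pdf[0] + ... + pdf[i]
// the sum of everything strictly to the right of bin i is
//   right[i] = pdf[i+1] + ... + pdf[n-1] = total - left[i]
// where total = left[n-1]. No second scan is needed.
template <typename DataT>
std::vector<DataT> right_sums_from_left_cdf(const std::vector<DataT>& left_cdf) {
  std::vector<DataT> right(left_cdf.size());
  if (left_cdf.empty()) return right;
  const DataT total = left_cdf.back();  // last entry of an inclusive scan
  for (std::size_t i = 0; i < left_cdf.size(); ++i) {
    right[i] = total - left_cdf[i];
  }
  return right;
}
```

If an inclusive right-to-left sum is wanted instead, adding the bin's own value `pdf[i]` to `right[i]` gives it, still without a second scan.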
@venkywonka venkywonka changed the title from "[WIP] ENH Decision Tree new backend computeSplitRegressionKernel histogram calculation optimization" to "[REVIEW] ENH Decision Tree new backend computeSplitRegressionKernel histogram calculation optimization" Apr 1, 2021
@venkywonka venkywonka marked this pull request as ready for review April 1, 2021 13:10
@venkywonka venkywonka requested a review from a team as a code owner April 1, 2021 13:10
@venkywonka venkywonka changed the title from "[REVIEW] ENH Decision Tree new backend computeSplitRegressionKernel histogram calculation optimization" to "[REVIEW] ENH Decision Tree new backend computeSplit*Kernel histogram calculation optimization" Apr 1, 2021
…ogram-calculation-optimization-computeSplitRegressionKernel
@codecov-io

Codecov Report

Merging #3674 (d4ee98f) into branch-0.19 (fd9ec89) will increase coverage by 2.21%.
The diff coverage is n/a.

@@               Coverage Diff               @@
##           branch-0.19    #3674      +/-   ##
===============================================
+ Coverage        80.70%   82.92%   +2.21%     
===============================================
  Files              227      227              
  Lines            17615    17591      -24     
===============================================
+ Hits             14217    14587     +370     
+ Misses            3398     3004     -394     
| Flag | Coverage Δ |
|---|---|
| dask | 45.31% <ø> (+0.32%) ⬆️ |
| non-dask | 74.95% <ø> (+2.03%) ⬆️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

| Impacted Files | Coverage Δ |
|---|---|
| python/cuml/__init__.py | 95.58% <ø> (+0.20%) ⬆️ |
| ...cuml/_thirdparty/sklearn/preprocessing/__init__.py | 100.00% <ø> (ø) |
| ...on/cuml/_thirdparty/sklearn/preprocessing/_data.py | 64.27% <ø> (+1.16%) ⬆️ |
| ...hirdparty/sklearn/preprocessing/_discretization.py | 83.33% <ø> (-0.88%) ⬇️ |
| ...l/_thirdparty/sklearn/preprocessing/_imputation.py | 85.54% <ø> (+22.74%) ⬆️ |
| python/cuml/_thirdparty/sklearn/utils/extmath.py | 56.89% <ø> (ø) |
| ...cuml/_thirdparty/sklearn/utils/skl_dependencies.py | 79.54% <ø> (+25.62%) ⬆️ |
| ...ython/cuml/_thirdparty/sklearn/utils/validation.py | 18.41% <ø> (-4.04%) ⬇️ |
| python/cuml/cluster/__init__.py | 100.00% <ø> (ø) |
| python/cuml/cluster/agglomerative.pyx | 96.47% <ø> (ø) |

... and 152 more

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 1554f14...d4ee98f. Read the comment docs.

@teju85
Member

teju85 commented Apr 4, 2021

@venkywonka please resolve conflicts

…stogram-calculation-optimization-computeSplitRegressionKernel
@venkywonka venkywonka force-pushed the enh-ext-histogram-calculation-optimization-computeSplitRegressionKernel branch from 607a926 to 6e39fef Compare April 4, 2021 08:05
Contributor

@hcho3 hcho3 left a comment

LGTM. This pull request produces a ~1.5x speedup on a few regression data sets I've tried. It does not impact the classification task.

gbm-bench
(only showing public data sets here)

Member

@teju85 teju85 left a comment

Changes LGTM.

@JohnZed
Contributor

JohnZed commented Apr 6, 2021

@gpucibot merge

@rapids-bot rapids-bot bot merged commit 9feecfb into rapidsai:branch-0.19 Apr 6, 2021
5 participants