ENH Decision Tree new backend computeSplitClassificationKernel histogram calculation and occupancy optimization #3616

Conversation

@venkywonka (Contributor) commented Mar 15, 2021

  • This PR introduces:
    • A faster way to compute the split histograms in ML::DecisionTree::computeSplitClassificationKernel. These histograms are used for node splitting in decision trees for the task of classification.
    • A change in the default gridDim.x in the launch configuration of the above kernel: instead of a fixed value of 4, it is now derived from the occupancy calculator and the other grid dimensions, raising occupancy to its theoretical limit.
  • Previously, the large number of atomic adds to shared memory limited the kernel times; these are now avoided by block-wide sum-scans that build the same histogram with far fewer atomic writes to shared memory (see the sketch after this list).
  • The resulting kernel time speedups are significant (up to 30x for some nodes).
  • computeSplitRegressionKernel has different shared-memory write patterns and deserves its own PR for optimization 😬
  • Tests will pass once #3690 (BUG fix: BatchedLevelAlgo DtClsTest & DtRegTest failing tests) is merged.
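A minimal, hypothetical CUDA sketch of the idea above (not the actual cuML kernel: the kernel name, the single-class simplification, and the assumption that the bin count fits in one thread block are all illustrative). Phase 1 builds a pdf histogram with one shared-memory atomic per sample; phase 2 converts it into the cdf form the splitter consumes via a block-wide inclusive sum-scan:

// Minimal sketch: pdf histogram + block-wide inclusive scan -> cdf histogram.
// Assumes NBINS <= TPB and a single class; names are illustrative, and the
// "d larger than all bins" overflow case is omitted for brevity.
#include <cub/cub.cuh>

template <int TPB, int NBINS>
__global__ void pdfToCdfHistogram(const float* data, int n,
                                  const float* bin_edges, int* cdf_out) {
  __shared__ int pdf_shist[NBINS];
  for (int i = threadIdx.x; i < NBINS; i += TPB) pdf_shist[i] = 0;
  __syncthreads();

  // Phase 1: one shared-memory atomic per sample (not one per sample per bin).
  for (int i = blockIdx.x * TPB + threadIdx.x; i < n; i += gridDim.x * TPB) {
    float d = data[i];
    for (int b = 0; b < NBINS; ++b) {
      if (d <= bin_edges[b]) {
        atomicAdd(pdf_shist + b, 1);
        break;
      }
    }
  }
  __syncthreads();

  // Phase 2: block-wide inclusive sum-scan turns the pdf into the cdf, i.e.
  // cdf[b] = number of samples with d <= bin_edges[b].
  using BlockScan = cub::BlockScan<int, TPB>;
  __shared__ typename BlockScan::TempStorage temp;
  int pdf_val = (threadIdx.x < NBINS) ? pdf_shist[threadIdx.x] : 0;
  int cdf_val;
  BlockScan(temp).InclusiveSum(pdf_val, cdf_val);
  if (threadIdx.x < NBINS) atomicAdd(cdf_out + threadIdx.x, cdf_val);
}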

@venkywonka venkywonka requested a review from a team as a code owner March 15, 2021 15:04
@venkywonka venkywonka added the improvement, non-breaking, and Perf labels Mar 15, 2021
@dantegd dantegd added the 4 - Waiting on Author label Mar 16, 2021
}
// case when d is larger than all bins
if(!breakflag) atomicAdd(pdf_shist + nbins*nclasses + label, 1);
Contributor:

This should never happen with the new way of quantile computation once #3586 gets merged. Since it incurs only a minimal penalty, we can keep this for now.

Comment on lines 456 to 460
int n_blks_for_rows = b.n_blks_for_rows(
colBlks,
(const void*)
computeSplitClassificationKernel<DataT, LabelT, IdxT, TPB_DEFAULT>,
TPB_DEFAULT, smemSize);
Contributor:

This value is computed again in workspaceSize. Is there a guarantee that calling n_blks_for_rows() would result in a consistent output?

Contributor Author (@venkywonka):

Yes, its value is basically constant: it equals the number of blocks needed for dimx when dimz (parallel nodes) is at its minimum of 1. This happens to be a function of n_blks_for_cols and the occupancy calculator. As far as occupancy is concerned, the theoretical limiter for now is register count. So, ceteris paribus, the output of n_blks_for_rows() should be the same when called from workspaceSize and from a preceding call of computeSplit.*Kernel (see the sketch below).
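For illustration, a hedged sketch of how a block count like this could be derived from the CUDA occupancy calculator; the helper name and the exact division scheme are assumptions, not the actual cuML implementation:

// Hypothetical sketch: derive a grid-x block count from occupancy.
// blocksForRows and n_blks_for_cols are illustrative names.
#include <algorithm>
#include <cuda_runtime.h>

int blocksForRows(int n_blks_for_cols, const void* kernel, int tpb,
                  size_t smem_size) {
  int dev, num_sms, max_blks_per_sm;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev);
  // Blocks per SM for this kernel/launch config; typically register-limited here.
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blks_per_sm, kernel, tpb,
                                                smem_size);
  // Fill the GPU when gridDim.z (parallel nodes) is at its minimum of 1:
  // spend all resident blocks, divided among the column-dimension blocks.
  return std::max(1, (num_sms * max_blks_per_sm) / n_blks_for_cols);
}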

@@ -362,50 +367,132 @@ __global__ void computeSplitClassificationKernel(
col = select(colIndex, treeid, node.info.unique_id, seed, input.N);
}

for (IdxT i = threadIdx.x; i < len; i += blockDim.x) shist[i] = 0;
// populating shared memory with initial values
Contributor:

Thanks for adding the comments. It helps me understand the code better.

Contributor Author (@venkywonka):

😄

-  auto isRight = d > sbins[b];  // no divergence
-  auto offset = b * 2 * nclasses + isRight * nclasses + label;
-  atomicAdd(shist + offset, 1);  // class hist
+  if (d <= sbins[b]) {  // shist (0 -> nbins*nclasses - 1)
Contributor:

This is going to cause warp divergence, unlike the old code. What's the rationale for this change?

Contributor Author (@venkywonka):

IIUC, this seems to be a small price to pay for reducing smem atomic writes by an order of magnitude.
The previous code does an atomicAdd() on every single inner-loop iteration (O(n²) atomics); the change does it only once per outer data sample and then breaks (O(n)). A simplified before/after is sketched below.
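A simplified before/after of the two shared-memory write patterns (hypothetical helper names; variable names follow the snippets above):

// Simplified, hypothetical device-side helpers contrasting the two patterns.
__device__ void histOld(float d, int label, const float* sbins, int* shist,
                        int nbins, int nclasses) {
  // Before: one shared-memory atomic per (sample, bin) pair -- nbins atomics
  // per sample, but no divergence.
  for (int b = 0; b < nbins; ++b) {
    int isRight = d > sbins[b];
    atomicAdd(shist + b * 2 * nclasses + isRight * nclasses + label, 1);
  }
}

__device__ void histNew(float d, int label, const float* sbins, int* pdf_shist,
                        int nbins, int nclasses) {
  // After: a single atomic per sample into its pdf bin; per-bin left/right
  // counts are recovered afterwards by a block-wide inclusive sum-scan.
  for (int b = 0; b < nbins; ++b) {
    if (d <= sbins[b]) {
      atomicAdd(pdf_shist + b * nclasses + label, 1);
      break;
    }
  }
}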

Member:

@hcho3 moreover, when we get #3606, I think this divergence will be eliminated entirely.

@hcho3 (Contributor) commented Mar 18, 2021

I tested this pull request last night and observed a significant performance regression. Here is a minimal reproducible example:

import os
import pickle
import tqdm
import time
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
from enum import Enum
from urllib.request import urlretrieve
from cuml.ensemble import RandomForestClassifier as cumlRandomForestClassifier
from cuml.ensemble import RandomForestRegressor as cumlRandomForestRegressor

pbar = None

class LearningTask(Enum):
    REGRESSION = 1
    CLASSIFICATION = 2
    MULTICLASS_CLASSIFICATION = 3

class Data:  # pylint: disable=too-few-public-methods,too-many-arguments
    def __init__(self, X_train, X_test, y_train, y_test, learning_task):
        self.X_train = X_train
        self.X_test = X_test
        self.y_train = y_train
        self.y_test = y_test
        self.learning_task = learning_task

def show_progress(block_num, block_size, total_size):
    global pbar
    if pbar is None:
        pbar = tqdm.tqdm(total=total_size / 1024, unit='kB')

    downloaded = block_num * block_size
    if downloaded < total_size:
        pbar.update(block_size / 1024)
    else:
        pbar.close()
        pbar = None

def retrieve(url, filename=None):
    return urlretrieve(url, filename, reporthook=show_progress)

def download_higgs():
    url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz'
    local_url = os.path.basename(url)
    pickle_url = "higgs.pkl"

    if os.path.exists(pickle_url):
        return pickle.load(open(pickle_url, "rb"))

    if not os.path.isfile(local_url):
        retrieve(url, local_url)
    higgs = pd.read_csv(local_url)
    X = higgs.iloc[:, 1:].to_numpy(dtype=np.float32)
    y = higgs.iloc[:, 0].to_numpy(dtype=np.float32)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=77,
                                                        test_size=0.2)
    data = Data(X_train, X_test, y_train, y_test, LearningTask.CLASSIFICATION)
    pickle.dump(data, open(pickle_url, "wb"), protocol=4)
    return data

def download_year():
    url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00203/YearPredictionMSD.txt' \
          '.zip'
    local_url = os.path.basename(url)
    pickle_url = "year.pkl"

    if os.path.exists(pickle_url):
        return pickle.load(open(pickle_url, "rb"))

    if not os.path.isfile(local_url):
        retrieve(url, local_url)
    year = pd.read_csv(local_url, header=None)
    X = year.iloc[:, 1:].to_numpy(dtype=np.float32)
    y = year.iloc[:, 0].to_numpy(dtype=np.float32)

    X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=False,
                                                        train_size=463715,
                                                        test_size=51630)

    data = Data(X_train, X_test, y_train, y_test, LearningTask.REGRESSION)
    pickle.dump(data, open(pickle_url, "wb"), protocol=4)
    return data

def main():
    higgs = download_higgs()
    year = download_year()
    
    # higgs
    tstart = time.perf_counter()
    clf = cumlRandomForestClassifier(max_features=1.0, random_state=0, n_bins=128, n_streams=4,
                                     bootstrap=True, n_estimators=100, max_depth=20,
                                     max_samples=0.01, split_algo=1, use_experimental_backend=True)
    clf.fit(higgs.X_train, higgs.y_train)
    tend = time.perf_counter()
    print(f'higgs: time elapsed = {tend - tstart} s')

    # year
    tstart = time.perf_counter()
    clf = cumlRandomForestRegressor(max_features=1.0, random_state=0, n_bins=128, n_streams=4,
                                    bootstrap=True, n_estimators=100, max_depth=20,
                                    max_samples=0.01, split_algo=1, use_experimental_backend=True)
    clf.fit(year.X_train, year.y_train)
    tend = time.perf_counter()
    print(f'year: time elapsed = {tend - tstart} s')

if __name__ == '__main__':
    main()

(The datasets get downloaded once and are cached in subsequent runs.)

Before (commit 14bd6c1):

mre.py:91: UserWarning: For reproducible results in Random Forest Classifier or for almost reproducible results in Random Forest Regressor, n_streams==1 is recommended. If n_streams is > 1, results may vary due to stream/thread timing differences, even when random_state is set
  clf = cumlRandomForestClassifier(max_features=1.0, random_state=0, n_bins=128, n_streams=4,
[W] [11:56:14.285130] Using experimental backend for growing trees

higgs: time elapsed = 11.467347843979951 s
mre.py:100: UserWarning: For reproducible results in Random Forest Classifier or for almost reproducible results in Random Forest Regressor, n_streams==1 is recommended. If n_streams is > 1, results may vary due to stream/thread timing differences, even when random_state is set
  clf = cumlRandomForestRegressor(max_features=1.0, random_state=0, n_bins=128, n_streams=4,
[W] [11:56:21.925520] Using experimental backend for growing trees

year: time elapsed = 58.9239871900063 s

After (this PR, with git merge branch-0.19):

mre.py:91: UserWarning: For reproducible results in Random Forest Classifier or for almost reproducible results in Random Forest Regressor, n_streams==1 is recommended. If n_streams is > 1, results may vary due to stream/thread timing differences, even when random_state is set
  clf = cumlRandomForestClassifier(max_features=1.0, random_state=0, n_bins=128, n_streams=4,
[W] [11:49:00.851581] Using experimental backend for growing trees

higgs: time elapsed = 17.181643451040145 s
mre.py:100: UserWarning: For reproducible results in Random Forest Classifier or for almost reproducible results in Random Forest Regressor, n_streams==1 is recommended. If n_streams is > 1, results may vary due to stream/thread timing differences, even when random_state is set
  clf = cumlRandomForestRegressor(max_features=1.0, random_state=0, n_bins=128, n_streams=4,
[W] [11:49:14.191898] Using experimental backend for growing trees

year: time elapsed = 277.97470924502704 s

My GPU is a Quadro RTX 8000, using CUDA 11.0 and driver 450.51.06.

@venkywonka venkywonka requested review from a team as code owners March 23, 2021 13:07
@github-actions github-actions bot added the CMake, conda, and Cython / Python labels Mar 23, 2021
@ajschmidt8 (Member) left a comment:

Approving ops-codeowner file changes

@@ -43,7 +43,7 @@ requirements:
     - libcumlprims {{ minor_version }}
     - cupy>=7.8.0,<9.0.0a0
     - treelite=1.0.0
-    - nccl>=2.5
+    - nccl>=2.8.4
Contributor:

Can you please leave out unrelated changes from this pull request?

Contributor Author (@venkywonka):

Sorry about that, I don't know why I thought rebasing was a good idea!

@venkywonka venkywonka force-pushed the enh-ext-histogram-calculation-optimization-for-computesplitclassificationkernel branch from be18c99 to c128de6 March 24, 2021 01:44
@github-actions github-actions bot removed the Cython / Python label Mar 24, 2021
@venkywonka venkywonka force-pushed the enh-ext-histogram-calculation-optimization-for-computesplitclassificationkernel branch from c128de6 to 32a773f March 30, 2021 06:02
@github-actions github-actions bot added the CMake, conda, Cython / Python, and gpuCI labels Mar 30, 2021
…stogram-calculation-optimization-for-computesplitclassificationkernel
@github-actions github-actions bot removed the Cython / Python, conda, CMake, and gpuCI labels Mar 30, 2021
@teju85 (Member) left a comment:

Changes LGTM. Thanks @venkywonka for this PR!

@teju85 (Member) commented Mar 30, 2021

@JohnZed or @dantegd, somehow this PR is expecting a Python approval too!? Can we get one please?

@teju85 (Member) commented Mar 31, 2021

@venkywonka seems like batched-level-algo unit-tests are now failing for 11.2?

@venkywonka (Contributor Author) commented:

> @venkywonka seems like batched-level-algo unit-tests are now failing for 11.2?

yes, working on it 😅 yet to find out why

@venkywonka (Contributor Author) commented:

The tests for this PR and PR #3674 will pass once PR #3690 gets merged.

@JohnZed JohnZed added this to PR-WIP in v0.19 Release via automation Apr 1, 2021
rapids-bot bot pushed a commit that referenced this pull request Apr 1, 2021
* This PR fixes the regressions shown by `BatchedLevelAlgo/DtClsTestF` and `BatchedLevelAlgo/DtRegTestF`, wherein the quantiles parameter passed to the `grow_tree` function was uninitialized garbage memory instead of the quantiles computed for each column.
* It also replaces the old method of computing quantiles (`preprocess_quantiles`) with a new, more accurate one (`computeQuantiles`).
* It removes an unnecessary memory allocation to `tempmem` in the setup phase of the test fixture.
* This fixes the failing `BatchedLevelAlgo/DtRegTestF` tests reported in issue #3406.
* It also fixes the failing `BatchedLevelAlgo/DtClsTestF` tests in PR #3616.

cc @teju85 @vinaydes @JohnZed @hcho3

Authors:
  - Venkat (https://github.com/venkywonka)

Approvers:
  - Thejaswi. N. S (https://github.com/teju85)
  - John Zedlewski (https://github.com/JohnZed)

URL: #3690
v0.19 Release automation moved this from PR-WIP to PR-Reviewer approved Apr 2, 2021
…ogram-calculation-optimization-for-computesplitclassificationkernel
@codecov-io commented:
Codecov Report

Merging #3616 (942968c) into branch-0.19 (fd9ec89) will increase coverage by 2.21%.
The diff coverage is n/a.

Impacted file tree graph

@@               Coverage Diff               @@
##           branch-0.19    #3616      +/-   ##
===============================================
+ Coverage        80.70%   82.92%   +2.21%     
===============================================
  Files              227      227              
  Lines            17615    17591      -24     
===============================================
+ Hits             14217    14587     +370     
+ Misses            3398     3004     -394     
Flag       Coverage Δ
dask       45.31% <ø> (+0.32%) ⬆️
non-dask   74.95% <ø> (+2.03%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
python/cuml/__init__.py 95.58% <ø> (+0.20%) ⬆️
...cuml/_thirdparty/sklearn/preprocessing/__init__.py 100.00% <ø> (ø)
...on/cuml/_thirdparty/sklearn/preprocessing/_data.py 64.27% <ø> (+1.16%) ⬆️
...hirdparty/sklearn/preprocessing/_discretization.py 83.33% <ø> (-0.88%) ⬇️
...l/_thirdparty/sklearn/preprocessing/_imputation.py 85.54% <ø> (+22.74%) ⬆️
python/cuml/_thirdparty/sklearn/utils/extmath.py 56.89% <ø> (ø)
...cuml/_thirdparty/sklearn/utils/skl_dependencies.py 79.54% <ø> (+25.62%) ⬆️
...ython/cuml/_thirdparty/sklearn/utils/validation.py 18.41% <ø> (-4.04%) ⬇️
python/cuml/cluster/__init__.py 100.00% <ø> (ø)
python/cuml/cluster/agglomerative.pyx 96.47% <ø> (ø)
... and 152 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 1554f14...942968c. Read the comment docs.

@teju85 (Member) commented Apr 4, 2021

@gpucibot merge

@rapids-bot rapids-bot bot merged commit 4bf0ba4 into rapidsai:branch-0.19 Apr 4, 2021
v0.19 Release automation moved this from PR-Reviewer approved to Done Apr 4, 2021
rapids-bot bot pushed a commit that referenced this pull request Apr 6, 2021
…tion optimization (#3674)

This is a follow-up of PR #3616 and should be merged after that.
This PR introduces:
* Modularizing the `pdf_to_cdf` conversion using an inclusive sum-scan into a device function so that it can be reused by both `ML::DecisionTree::computeSplitClassificationKernel` and `ML::DecisionTree::computeSplitRegressionKernel`.
* Integrating the above-mentioned device function to calculate the prediction sums and counts in `ML::DecisionTree::computeSplitRegressionKernel`. These histograms are used for node splitting in decision trees for the task of regression.
* The rationale for this optimization follows the same explanation given in PR #3616.
* As of now, only the first pass has been optimized using sum-scans. (A hedged sketch of what such a device function might look like follows below.)
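For illustration only — the actual cuML signature may differ — a minimal sketch of a reusable pdf_to_cdf device function, assuming nbins <= TPB and CUB available (as in the kernel sketch earlier in this thread):

#include <cub/cub.cuh>

// Hypothetical sketch: convert an nbins-wide shared-memory pdf histogram into
// its inclusive prefix sum (cdf) in place; callable from any kernel launched
// with TPB threads per block. Assumes nbins <= TPB.
template <typename DataT, int TPB>
__device__ void pdf_to_cdf(DataT* shist, int nbins) {
  using BlockScan = cub::BlockScan<DataT, TPB>;
  __shared__ typename BlockScan::TempStorage temp;
  DataT pdf_val = (threadIdx.x < nbins) ? shist[threadIdx.x] : DataT(0);
  DataT cdf_val;
  BlockScan(temp).InclusiveSum(pdf_val, cdf_val);
  __syncthreads();  // all reads of shist must finish before overwriting in place
  if (threadIdx.x < nbins) shist[threadIdx.x] = cdf_val;
}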

Authors:
  - Venkat (https://github.com/venkywonka)

Approvers:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Thejaswi. N. S (https://github.com/teju85)

URL: #3674
Labels
4 - Waiting on Author (Waiting for author to respond to review), improvement (Improvement / enhancement to an existing function), non-breaking (Non-breaking change), Perf (Related to runtime performance of the underlying code)
Projects
v0.19 Release: Done

Development
Successfully merging this pull request may close these issues: none yet.

8 participants