[Tabular] Remove TorchThreadManager in TabularNN and hack in LightGBM #2472

Merged · 2 commits · Jan 17, 2023

Conversation

@liangfu liangfu commented Nov 23, 2022

Issue #, if available:

Description of changes:

  1. Predict with the tabular NN without modifying num_threads, since we have upgraded XGBoost to >=1.6. (See the sketch below for the kind of wrapper this removes.)
  2. Removed the hack that changes num_threads in LightGBM, since the upstream bug (Passing default num_threads to booster.predict not working, microsoft/LightGBM#4607) has been fixed.
    a. See original topic: Ray parallel #1329 (comment)
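
For context, here is a minimal sketch of the kind of thread-limiting context manager that item 1 removes. The name matches the PR title, but the body is an approximation for illustration, not the exact AutoGluon implementation:

import torch

class TorchThreadManager:
    """Sketch: temporarily cap torch's intra-op thread count,
    restoring the original value on exit."""
    def __init__(self, num_threads: int):
        self.num_threads = num_threads
        self._original = None

    def __enter__(self):
        self._original = torch.get_num_threads()
        torch.set_num_threads(self.num_threads)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.set_num_threads(self._original)

Because torch.set_num_threads reconfigures a global thread pool, wrapping every predict call in such a context can dominate latency at batch size 1, which is consistent with the timings below.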

For an ensemble containing both a TabularFastAI model and an XGBoost model, overall inference latency is reduced from 49 ms to 33 ms (batch size = 1). With TorchThreadManager removed, XGBoost prediction time drops from 16 ms to 1.6 ms. (Tested on Linux with the AdultIncomeBinaryClassification task, trained with the high_quality preset.)

Code snippet to reproduce the results:

import time
from autogluon.tabular import TabularDataset, TabularPredictor

# Load a previously trained predictor:
label = 'class'  # the column we want to predict
save_path_clone_opt = '~/Downloads/AdultIncomeBinaryClassificationModel_0_6_2'

predictor = TabularPredictor.load(save_path_clone_opt)

# Inference time:
test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')  # another Pandas DataFrame
test_data = test_data.head(1)
test_data = test_data.drop(labels=[label], axis=1)  # delete labels from test data since we wouldn't have them in practice

predictor.persist_models()
for _ in range(10):
    tic = time.time()
    y_pred = predictor.predict(test_data)
    print(f"elapsed: {(time.time()-tic)*1000:.0f} ms")

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@github-actions

Job PR-2472-c4e3e95 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-2472/c4e3e95/index.html

@liangfu liangfu changed the title [Tabular] Speedup single-batch tabular NN prediction [Tabular] Removing TorchThreadManager due to updated XGBoost version Dec 7, 2022
@liangfu liangfu commented Dec 7, 2022

To verify the difference, we ran a simple experiment on a few datasets on macOS.

With the baseline implementation in the current master branch, we get:

        name                task_name    test_score  time_fit  time_predict              eval_metric    test_error  fold  repeat  sample  task_id problem_type
0   NN_TORCH                     wilt      0.955408  3.294026      0.011839                  roc_auc      0.044592     0       0       0   146820       binary
1   NN_TORCH                 credit-g      0.750952  4.086180      0.025251                  roc_auc      0.249048     0       0       0   168757       binary
2   NN_TORCH                  jasmine      0.852841  2.429345      0.045434                  roc_auc      0.147159     0       0       0   168911       binary
3   NN_TORCH                 madeline      0.666951  4.141213      0.013555                  roc_auc      0.333049     0       0       0   190392       binary
4   NN_TORCH               eucalyptus     -0.654584  2.472429      0.019346                 log_loss      0.654584     0       0       0   359954   multiclass
5   NN_TORCH              qsar-biodeg      0.946825  1.641296      0.026339                  roc_auc      0.053175     0       0       0   359956       binary
6   NN_TORCH                      pc4      0.913628  1.199499      0.025022                  roc_auc      0.086372     0       0       0   359958       binary
7   NN_TORCH                      kc1      0.812064  4.014099      0.018316                  roc_auc      0.187936     0       0       0   359962       binary
8   NN_TORCH                  segment     -0.202407  5.373858      0.015987                 log_loss      0.202407     0       0       0   359963   multiclass
9   NN_TORCH  Internet-Advertisements      0.967777  6.372126      0.384555                  roc_auc      0.032223     0       0       0   359966       binary
10  NN_TORCH                  sylvine      0.963765  5.299525      0.020255                  roc_auc      0.036235     0       0       0   359972       binary
11  NN_TORCH                Satellite      0.915933  8.992411      0.008237                  roc_auc      0.084067     0       0       0   359975       binary
12  NN_TORCH                  tecator     -1.502424  1.430726      0.013959  root_mean_squared_error      1.502424     0       0       0   359934   regression
13  NN_TORCH      MIP-2016-regression -23299.954074  2.555077      0.067798  root_mean_squared_error  23299.954074     0       0       0   359947   regression
14  NN_TORCH                   boston     -2.874179  2.369670      0.018515  root_mean_squared_error      2.874179     0       0       0   359950   regression
15       XGB                     wilt      0.987403  0.390685      0.009557                  roc_auc      0.012597     0       0       0   146820       binary
16       XGB                 credit-g      0.832381  0.389215      0.022072                  roc_auc      0.167619     0       0       0   168757       binary
17       XGB                  jasmine      0.867517  1.065899      0.057437                  roc_auc      0.132483     0       0       0   168911       binary
18       XGB                 madeline      0.893582  4.085552      0.016688                  roc_auc      0.106418     0       0       0   190392       binary
19       XGB               eucalyptus     -0.752650  1.595121      0.018889                 log_loss      0.752650     0       0       0   359954   multiclass
20       XGB              qsar-biodeg      0.946429  0.532794      0.013417                  roc_auc      0.053571     0       0       0   359956       binary
21       XGB                      pc4      0.929253  0.532072      0.011503                  roc_auc      0.070747     0       0       0   359958       binary
22       XGB                      kc1      0.800454  0.495211      0.009311                  roc_auc      0.199546     0       0       0   359962       binary
23       XGB                  segment     -0.173008  2.183501      0.010676                 log_loss      0.173008     0       0       0   359963   multiclass
24       XGB  Internet-Advertisements      0.960492  6.906823      0.429402                  roc_auc      0.039508     0       0       0   359966       binary
25       XGB                  sylvine      0.979876  0.659569      0.009840                  roc_auc      0.020124     0       0       0   359972       binary
26       XGB                Satellite      0.988924  0.411049      0.009124                  roc_auc      0.011076     0       0       0   359975       binary
27       XGB                  tecator     -1.346267  1.096021      0.009829  root_mean_squared_error      1.346267     0       0       0   359934   regression
28       XGB      MIP-2016-regression  -1011.745554  3.254313      0.024803  root_mean_squared_error   1011.745554     0       0       0   359947   regression
29       XGB                   boston     -2.569359  0.457009      0.017173  root_mean_squared_error      2.569359     0       0       0   359950   regression

After removing TorchThreadManager:

        name                task_name    test_score  time_fit  time_predict              eval_metric    test_error  fold  repeat  sample  task_id problem_type
0   NN_TORCH                     wilt      0.955408  3.152081      0.011707                  roc_auc      0.044592     0       0       0   146820       binary
1   NN_TORCH                 credit-g      0.750952  3.981422      0.024591                  roc_auc      0.249048     0       0       0   168757       binary
2   NN_TORCH                  jasmine      0.852841  2.386335      0.046799                  roc_auc      0.147159     0       0       0   168911       binary
3   NN_TORCH                 madeline      0.666951  4.174943      0.014505                  roc_auc      0.333049     0       0       0   190392       binary
4   NN_TORCH               eucalyptus     -0.654584  2.639098      0.019309                 log_loss      0.654584     0       0       0   359954   multiclass
5   NN_TORCH              qsar-biodeg      0.946825  1.637035      0.026693                  roc_auc      0.053175     0       0       0   359956       binary
6   NN_TORCH                      pc4      0.913628  1.211320      0.029044                  roc_auc      0.086372     0       0       0   359958       binary
7   NN_TORCH                      kc1      0.812064  4.110126      0.018689                  roc_auc      0.187936     0       0       0   359962       binary
8   NN_TORCH                  segment     -0.202407  5.945512      0.016602                 log_loss      0.202407     0       0       0   359963   multiclass
9   NN_TORCH  Internet-Advertisements      0.967777  6.435192      0.396387                  roc_auc      0.032223     0       0       0   359966       binary
10  NN_TORCH                  sylvine      0.963765  5.522072      0.017301                  roc_auc      0.036235     0       0       0   359972       binary
11  NN_TORCH                Satellite      0.915933  8.678970      0.007628                  roc_auc      0.084067     0       0       0   359975       binary
12  NN_TORCH                  tecator     -1.502424  1.390313      0.013227  root_mean_squared_error      1.502424     0       0       0   359934   regression
13  NN_TORCH      MIP-2016-regression -23299.954074  2.692366      0.062395  root_mean_squared_error  23299.954074     0       0       0   359947   regression
14  NN_TORCH                   boston     -2.874179  2.298750      0.015002  root_mean_squared_error      2.874179     0       0       0   359950   regression
15       XGB                     wilt      0.987403  0.364779      0.009304                  roc_auc      0.012597     0       0       0   146820       binary
16       XGB                 credit-g      0.832381  0.374108      0.021996                  roc_auc      0.167619     0       0       0   168757       binary
17       XGB                  jasmine      0.867517  1.012889      0.057122                  roc_auc      0.132483     0       0       0   168911       binary
18       XGB                 madeline      0.893582  3.720246      0.016597                  roc_auc      0.106418     0       0       0   190392       binary
19       XGB               eucalyptus     -0.752650  1.649892      0.023875                 log_loss      0.752650     0       0       0   359954   multiclass
20       XGB              qsar-biodeg      0.946429  0.561978      0.013795                  roc_auc      0.053571     0       0       0   359956       binary
21       XGB                      pc4      0.929253  0.618598      0.013108                  roc_auc      0.070747     0       0       0   359958       binary
22       XGB                      kc1      0.800454  0.541194      0.010329                  roc_auc      0.199546     0       0       0   359962       binary
23       XGB                  segment     -0.173008  2.373941      0.011351                 log_loss      0.173008     0       0       0   359963   multiclass
24       XGB  Internet-Advertisements      0.960492  6.785063      0.481112                  roc_auc      0.039508     0       0       0   359966       binary
25       XGB                  sylvine      0.979876  0.702935      0.010390                  roc_auc      0.020124     0       0       0   359972       binary
26       XGB                Satellite      0.988924  0.402077      0.010818                  roc_auc      0.011076     0       0       0   359975       binary
27       XGB                  tecator     -1.346267  1.015031      0.011103  root_mean_squared_error      1.346267     0       0       0   359934   regression
28       XGB      MIP-2016-regression  -1011.745554  2.904379      0.023663  root_mean_squared_error   1011.745554     0       0       0   359947   regression
29       XGB                   boston     -2.569359  0.457170      0.014724  root_mean_squared_error      2.569359     0       0       0   359950   regression

I would conclude that removing the thread-reset statement makes no significant difference: test scores are identical and fit/predict times are comparable.
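
A quick way to check that claim from the two tables above (a sketch; the CSV file names are hypothetical stand-ins for wherever the two result tables are saved):

import pandas as pd

baseline = pd.read_csv('results_master.csv')           # hypothetical: baseline table above
candidate = pd.read_csv('results_no_thread_manager.csv')  # hypothetical: table after this PR

merged = baseline.merge(candidate, on=['name', 'task_name'], suffixes=('_base', '_pr'))
# Only thread handling changed, so scores should match exactly:
assert (merged['test_score_base'] == merged['test_score_pr']).all()
merged['predict_speedup'] = merged['time_predict_base'] / merged['time_predict_pr']
print(merged[['name', 'task_name', 'predict_speedup']].to_string(index=False))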

Here are more visualized details.

[Image: per-task visualization of the benchmark results above.]

@github-actions

Job PR-2472-2d50e04 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-2472/2d50e04/index.html

@github-actions

Job PR-2472-b670121 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-2472/b670121/index.html

@liangfu liangfu changed the title [Tabular] Removing TorchThreadManager due to updated XGBoost version [Tabular] Remove TorchThreadManager in TabularNN and hack in LightGBM Jan 12, 2023
@liangfu liangfu requested review from yinweisu and Innixma and removed request for yinweisu January 12, 2023 05:50
@Innixma Innixma commented Jan 17, 2023

Analysis on AdultIncome shows a major speedup for small-batch inference on m6i.16xlarge: ~2.5x faster!

Mainline
------------------------
batch_size=1
Average: 4.5 ms (LightGBM_BAG_L1_FULL)
Average: 3.2 ms (CatBoost_BAG_L1_FULL)
Average: 11.7 ms (NeuralNetFastAI_BAG_L1_FULL)
Average: 13.7 ms (XGBoost_BAG_L1_FULL)
Average: 59.0 ms (WeightedEnsemble_L2_FULL)
------------------------
batch_size=100
Average: 4.2 ms (LightGBM_BAG_L1_FULL)
Average: 3.5 ms (CatBoost_BAG_L1_FULL)
Average: 11.1 ms (NeuralNetFastAI_BAG_L1_FULL)
Average: 6.3 ms (XGBoost_BAG_L1_FULL)
Average: 56.3 ms (WeightedEnsemble_L2_FULL)
------------------------
batch_size=10000
Average: 21.6 ms (LightGBM_BAG_L1_FULL)
Average: 21.7 ms (CatBoost_BAG_L1_FULL)
Average: 120.0 ms (NeuralNetFastAI_BAG_L1_FULL)
Average: 36.5 ms (XGBoost_BAG_L1_FULL)
Average: 163.7 ms (WeightedEnsemble_L2_FULL)

This PR
------------------------
batch_size=1
Average: 3.9 ms (LightGBM_BAG_L1_FULL)
Average: 3.3 ms (CatBoost_BAG_L1_FULL)
Average: 10.8 ms (NeuralNetFastAI_BAG_L1_FULL)
Average: 3.7 ms (XGBoost_BAG_L1_FULL)
Average: 21.2 ms (WeightedEnsemble_L2_FULL)
------------------------
batch_size=100
Average: 4.0 ms (LightGBM_BAG_L1_FULL)
Average: 3.7 ms (CatBoost_BAG_L1_FULL)
Average: 9.3 ms (NeuralNetFastAI_BAG_L1_FULL)
Average: 3.7 ms (XGBoost_BAG_L1_FULL)
Average: 20.9 ms (WeightedEnsemble_L2_FULL)
------------------------
batch_size=10000
Average: 23.5 ms (LightGBM_BAG_L1_FULL)
Average: 22.7 ms (CatBoost_BAG_L1_FULL)
Average: 114.4 ms (NeuralNetFastAI_BAG_L1_FULL)
Average: 34.9 ms (XGBoost_BAG_L1_FULL)
Average: 156.8 ms (WeightedEnsemble_L2_FULL)
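
A sketch of how these per-batch-size, per-model averages can be collected (the model path is a hypothetical stand-in; predict's model= argument selects an individual model such as 'XGBoost_BAG_L1_FULL'):

import time
from autogluon.tabular import TabularDataset, TabularPredictor

predictor = TabularPredictor.load('~/models/adult_income')  # hypothetical path
predictor.persist_models()
test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')

for batch_size in (1, 100, 10000):
    batch = test_data.head(batch_size)
    for model in predictor.get_model_names():
        times = []
        for _ in range(10):
            tic = time.time()
            predictor.predict(batch, model=model)
            times.append(time.time() - tic)
        avg_ms = sum(times) / len(times) * 1000
        print(f"batch_size={batch_size}: Average: {avg_ms:.1f} ms ({model})")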

@Innixma Innixma left a comment

LGTM, great work!
