# 1. Create data:

In [1]:
from sklearn.datasets import make_classification, make_regression
Xcls, ycls = make_classification(n_samples=200, n_features=20, n_classes=3, random_state=42, n_clusters_per_class=3, n_informative=5)
Xreg, yreg = make_regression(n_samples=200, n_features=20, random_state=42)

# 2. Create and Fit RandomForest and GradientBoosting models for Regression and Classification

In [2]:
import ntree_tuning as ntt

rf_cls = ntt.Ntree_RandForest_Classifier(n_estimators=100)
rf_cls.fit(Xcls, ycls)

rf_reg = ntt.Ntree_RandForest_Regressor(n_estimators=100)
rf_reg.fit(Xreg, yreg)

gb_cls = ntt.Ntree_GradBoost_Classifier(n_estimators=100, subsample=0.8)
gb_cls.fit(Xcls, ycls)

gb_reg = ntt.Ntree_GradBoost_Regressor(n_estimators=100, subsample=0.8)
gb_reg.fit(Xreg, yreg)

# 3. Tune ntree

Depending on whether you have a RF or GB model you can now call tune_ntree_rf or tune_ntree_gb to get the oob_error dicts for the specified values of ntrees

In [None]:
gb_reg.tune_ntrees()

{0: 15368.289464850377,
 1: 15692.152619371325,
 2: 19280.115243613312,
 3: 14177.668510671176,
 4: 13207.571136518896,
 5: 14399.127968009814,
 6: 12739.766428943698,
 7: 11181.291537212383,
 8: 10996.459394621528,
 9: 5935.239839333188,
 10: 9003.533773195279,
 11: 6595.96419620331,
 12: 5065.344678955262,
 13: 4827.5120846299415,
 14: 4940.668000903138,
 15: 4864.72476373683,
 16: 3647.0387848793303,
 17: 5756.851764675664,
 18: 4044.177476596138,
 19: 3253.139711828321,
 20: 4266.356331808298,
 21: 3405.0076091254205,
 22: 3852.169661357507,
 23: 2941.0786546219815,
 24: 2217.3921969060057,
 25: 3123.8610256968404,
 26: 2720.8943503867376,
 27: 2141.1610924345027,
 28: 2265.8482274704156,
 29: 2256.32528807266,
 30: 1664.451233976793,
 31: 1593.7797449798277,
 32: 1331.4842820736778,
 33: 1145.074911527719,
 34: 1086.374570260314,
 35: 1133.1343990057583,
 36: 900.9990321319083,
 37: 963.1611121833939,
 38: 670.5274266160728,
 39: 612.5569783760851,
 40: 836.4310839703594,
 41: 729

In [None]:
# Gradient Boosting
print(gb_reg.tune_ntrees())
print(gb_cls.tune_ntrees())


# Random Forests
min_trees = 20
max_trees = 80
delta_trees = 5

print(rf_reg.tune_ntrees(Xreg, yreg, min_trees, max_trees, delta_trees))
print(rf_cls.tune_ntrees(Xcls, ycls, min_trees, max_trees, delta_trees))

{0: 15368.289464850377, 1: 15692.152619371325, 2: 19280.115243613312, 3: 14177.668510671176, 4: 13207.571136518896, 5: 14399.127968009814, 6: 12739.766428943698, 7: 11181.291537212383, 8: 10996.459394621528, 9: 5935.239839333188, 10: 9003.533773195279, 11: 6595.96419620331, 12: 5065.344678955262, 13: 4827.5120846299415, 14: 4940.668000903138, 15: 4864.72476373683, 16: 3647.0387848793303, 17: 5756.851764675664, 18: 4044.177476596138, 19: 3253.139711828321, 20: 4266.356331808298, 21: 3405.0076091254205, 22: 3852.169661357507, 23: 2941.0786546219815, 24: 2217.3921969060057, 25: 3123.8610256968404, 26: 2720.8943503867376, 27: 2141.1610924345027, 28: 2265.8482274704156, 29: 2256.32528807266, 30: 1664.451233976793, 31: 1593.7797449798277, 32: 1331.4842820736778, 33: 1145.074911527719, 34: 1086.374570260314, 35: 1133.1343990057583, 36: 900.9990321319083, 37: 963.1611121833939, 38: 670.5274266160728, 39: 612.5569783760851, 40: 836.4310839703594, 41: 729.2879167985118, 42: 765.760074670273, 43:

In [9]:
gb_reg.predict_ntree(Xcls, 20)

array([  881.888144  ,   -76.96110596,  -668.04154464, -1601.42389268,
         974.28578073,  -477.72528483,   727.58773103,  -224.8226607 ,
       -1646.41203838,  -574.28874036,   497.4270578 ,  1401.98706292,
       -1629.11085483,  1076.79156797,  1379.08380746,   678.38154477,
           2.77368291,  -179.54151065,   766.66771012,    22.65144181,
        2085.73179693,   711.77673791,   151.05979354,   201.94733445,
         -52.2387442 ,  -754.99544305, -1740.78672225,  -594.36545375,
         150.94729919,  -427.15321131,   354.17045166,  1280.23621806,
        1523.23823847,   717.65560759,  -688.47266282,  -462.65471172,
        -873.93249379,   850.99729551,   561.78571462,  1280.35379557,
       -1461.58972863,   449.12577767,  -310.55597836,  1685.51393951,
        -551.17861668,  -195.89852716,   871.43614053,   722.83150948,
        -821.53848636,   368.39072768,  -929.05996481,  1031.84825628,
        -292.12236499,   129.3148564 ,   961.24060782, -1370.4084937 ,
      