# 1. Create data:

In [1]:
from sklearn.datasets import make_classification, make_regression
# Synthetic classification problem: 200 samples, 20 features (5 informative),
# 3 classes with 3 clusters per class. Seeded so the data is reproducible.
Xcls, ycls = make_classification(
    n_samples=200,
    n_features=20,
    n_informative=5,
    n_classes=3,
    n_clusters_per_class=3,
    random_state=42,
)

# Synthetic regression problem with the same sample/feature counts and seed.
Xreg, yreg = make_regression(
    n_samples=200,
    n_features=20,
    random_state=42,
)

# 2. Create and Fit RandomForest and GradientBoosting models for Regression and Classification

In [2]:
import ntrees_tuning as ntt

# Constructor arguments shared by the two random forests and by the two
# gradient-boosting models respectively.
rf_kwargs = {"n_estimators": 100}
# NOTE(review): subsample < 1 is presumably what makes out-of-bag estimates
# available to the boosting models -- confirm against ntrees_tuning.
gb_kwargs = {"n_estimators": 100, "subsample": 0.8}

# Random forest, classification data.
rf_cls = ntt.Ntree_RandForest_Classifier(**rf_kwargs)
rf_cls.fit(Xcls, ycls)

# Random forest, regression data.
rf_reg = ntt.Ntree_RandForest_Regressor(**rf_kwargs)
rf_reg.fit(Xreg, yreg)

# Gradient boosting, classification data.
gb_cls = ntt.Ntree_GradBoost_Classifier(**gb_kwargs)
gb_cls.fit(Xcls, ycls)

# Gradient boosting, regression data.
gb_reg = ntt.Ntree_GradBoost_Regressor(**gb_kwargs)
gb_reg.fit(Xreg, yreg)

# 3. Tune ntrees

You can then call the `tune_ntrees` method to get a dictionary that maps each `ntrees` value to its out-of-bag (OOB) error.

In [3]:
# Gradient Boosting: tune_ntrees is called without arguments on the fitted
# models, regressor first and then classifier.
for _gb_model in (gb_reg, gb_cls):
    print(_gb_model.tune_ntrees())


# Random Forests: tune_ntrees is given the data plus three integers.
# NOTE(review): presumably these are the smallest and largest tree counts to
# evaluate and the step between them -- confirm against ntrees_tuning.
min_trees, max_trees, delta_trees = 20, 80, 5

for _rf_model, _rf_X, _rf_y in ((rf_reg, Xreg, yreg), (rf_cls, Xcls, ycls)):
    print(_rf_model.tune_ntrees(_rf_X, _rf_y, min_trees, max_trees, delta_trees))

{0: 15990.812143269979, 1: 23942.69055811961, 2: 14349.220767587674, 3: 12876.947359727408, 4: 14860.799145739413, 5: 8883.46222051006, 6: 11571.85752257648, 7: 13553.634817504048, 8: 6544.121028825166, 9: 7226.189106444834, 10: 9354.09750247949, 11: 7396.585979205105, 12: 6180.127178420744, 13: 7189.613572895318, 14: 6321.040894098903, 15: 6108.492777410982, 16: 5864.564979226343, 17: 4747.021027921228, 18: 4416.518211293253, 19: 3633.258046189664, 20: 2943.7672415743145, 21: 2670.953595392302, 22: 3013.291399777899, 23: 1894.2719256375835, 24: 2393.40347008514, 25: 1453.8759459736762, 26: 1059.0387442553094, 27: 1826.7725152284602, 28: 2529.4145478331593, 29: 1471.7054603074982, 30: 1742.6141727035908, 31: 1840.536124920495, 32: 1468.791407643776, 33: 769.7246607174854, 34: 1070.730417443406, 35: 1119.0678820086628, 36: 1005.0298097790895, 37: 1015.727633970756, 38: 1044.6867917130367, 39: 998.3590198219852, 40: 895.0961264381152, 41: 991.5768729031184, 42: 882.6735748507683, 43: 579

# 4. Predict with ntrees

In [4]:
# Predict on the regression data with the fitted gradient-boosting regressor,
# passing ntrees=10, and print the result.
preds_ntrees_10 = gb_reg.predict_ntrees(Xreg, ntrees=10)
print(preds_ntrees_10)

[  480.08569445   620.21590926   346.44928567  -440.52576623
   489.04021561  -972.59115738  -440.52576623  -397.10807775
   824.57054156    53.62816358   346.44928567  -517.45296163
 -1769.62768851   176.29838001   835.22176328  -115.6085277
  -117.38744185 -1303.20132426   797.5601438   -278.38146476
   458.01674037    86.63496303 -1397.71415899  1732.6235721
   604.98367125   264.89912098  1576.98695729   823.43952867
   369.84212559   817.81820252   -82.42399104   648.67703843
 -1647.32894585  1150.27552682  -166.07839813  -261.16605721
  -230.01896879  -668.89476194  -107.65692433   997.81506555
  -714.11010156 -1035.47543531   648.67703843  -673.68720847
   739.07191223  -635.737446    -626.23856252  -760.69456781
 -1228.92501512  -693.73653018   -94.92605118   884.62733349
   603.32153256   168.97709477    77.12340077  -508.63128479
  -625.29251825   766.82152571  -673.68720847  -918.50303726
 -1068.07499745  -862.34434175  1235.89471154  1352.56488234
   750.11272628  -954.1441