# 1. Create synthetic data

In [1]:
# Build two synthetic datasets with a fixed seed so a fresh Restart & Run All
# reproduces them exactly:
#   * Xcls/ycls - a 3-class classification problem (5 informative features,
#                 3 clusters per class)
#   * Xreg/yreg - a regression problem
# Both have 200 samples x 20 features.
from sklearn.datasets import make_classification, make_regression

SEED = 42
N_SAMPLES = 200
N_FEATURES = 20

Xcls, ycls = make_classification(
    n_samples=N_SAMPLES,
    n_features=N_FEATURES,
    n_informative=5,
    n_classes=3,
    n_clusters_per_class=3,
    random_state=SEED,
)
Xreg, yreg = make_regression(
    n_samples=N_SAMPLES,
    n_features=N_FEATURES,
    random_state=SEED,
)

# 2. Create and fit Random Forest and Gradient Boosting models for regression and classification

In [2]:
import ntree_tuning as ntt

# Hyperparameters shared by the four models below.
N_ESTIMATORS = 100
# Passed as `subsample` to the GB wrappers; presumably forwarded to
# scikit-learn's row-subsampling fraction per boosting stage - confirm in
# ntree_tuning.
GB_SUBSAMPLE = 0.8

# NOTE(review): no random_state is passed to any model here. Unless
# ntree_tuning seeds them internally, bootstrap/subsampling randomness (and so
# the OOB numbers in section 3) may differ between runs. If these wrappers
# accept scikit-learn's `random_state`, pass a fixed seed - TODO confirm.

# Random Forest models
rf_cls = ntt.Ntree_RF_Classifier(n_estimators=N_ESTIMATORS)
rf_cls.fit(Xcls, ycls)

rf_reg = ntt.Ntree_RF_Regressor(n_estimators=N_ESTIMATORS)
rf_reg.fit(Xreg, yreg)

# Gradient Boosting models
gb_cls = ntt.Ntree_GB_Classifier(n_estimators=N_ESTIMATORS, subsample=GB_SUBSAMPLE)
gb_cls.fit(Xcls, ycls)

gb_reg = ntt.Ntree_GB_Regressor(n_estimators=N_ESTIMATORS, subsample=GB_SUBSAMPLE)
gb_reg.fit(Xreg, yreg)

# 3. Tune the number of trees (ntree)

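Depending on whether you have a Random Forest (RF) or a Gradient Boosting (GB) model, call `tune_ntree_rf` or `tune_ntree_gb` to get a dict of out-of-bag (OOB) errors. `tune_ntree_rf` also needs the training data and the range of tree counts to evaluate (`min_trees`, `max_trees`, `delta_trees`), whereas `tune_ntree_gb` takes only the fitted model.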

In [3]:
# --- Gradient Boosting: only the fitted model is passed ---
gb_reg_oob = ntt.tune_ntree_gb(gb_reg)
print(gb_reg_oob)
gb_cls_oob = ntt.tune_ntree_gb(gb_cls)
print(gb_cls_oob)


# --- Random Forests: the training data and a tree-count range are passed ---
# The range is given by min_trees, max_trees and delta_trees (step); whether
# max_trees itself is evaluated depends on ntree_tuning - confirm there.
min_trees, max_trees, delta_trees = 20, 80, 5

# Each result is printed right after it is computed, so any output the tuners
# emit themselves stays interleaved exactly as before.
rf_reg_oob = ntt.tune_ntree_rf(rf_reg, Xreg, yreg, min_trees, max_trees, delta_trees)
print(rf_reg_oob)
rf_cls_oob = ntt.tune_ntree_rf(rf_cls, Xcls, ycls, min_trees, max_trees, delta_trees)
print(rf_cls_oob)

{0: 19241.846414301435, 1: 13686.583163185049, 2: 18273.22712049091, 3: 16259.661023667102, 4: 13350.104345503243, 5: 13778.676825829734, 6: 8936.64282581884, 7: 9879.574239527057, 8: 7585.820368615811, 9: 6608.072923623362, 10: 7838.802699143571, 11: 7588.95779732347, 12: 6299.83111284095, 13: 6952.335540872297, 14: 4933.3776727784025, 15: 3934.365890343285, 16: 6040.505867563379, 17: 5366.299382054883, 18: 2233.7970399142655, 19: 2713.6321349916684, 20: 2926.093963854921, 21: 2318.1775173067495, 22: 4285.525719730659, 23: 1851.1592782848688, 24: 2753.6813117483275, 25: 1662.2786319856743, 26: 2693.366993881803, 27: 1618.0929135910612, 28: 1344.4929385778, 29: 1725.1314076499616, 30: 2134.516607997041, 31: 1285.2895096915113, 32: 1052.156492613658, 33: 1923.171729997312, 34: 1240.0481924544013, 35: 1308.1522113898318, 36: 1454.3979272537163, 37: 872.7222994227038, 38: 1292.607581511681, 39: 758.4419990572627, 40: 739.8208465683282, 41: 751.3394594112896, 42: 971.8747560864437, 43: 104