# 1. Create data

In [2]:
from sklearn.datasets import make_classification, make_regression
# Synthetic 3-class classification problem: 200 samples, 20 features,
# of which 5 are informative. Seeded so the data is reproducible.
Xcls, ycls = make_classification(
    n_samples=200,
    n_features=20,
    n_informative=5,
    n_classes=3,
    n_clusters_per_class=3,
    random_state=42,
)

# Synthetic regression problem with the same sample/feature counts and seed.
Xreg, yreg = make_regression(
    n_samples=200,
    n_features=20,
    random_state=42,
)

# 2. Create and Fit RandomForest and GradientBoosting models for Regression and Classification

In [3]:
import ntree_tuning as ntt

# Hyperparameters shared by the demo models below.
N_ESTIMATORS = 100
# NOTE(review): subsample < 1 is presumably required so the gradient-boosting
# models have out-of-bag samples to score against — confirm in ntree_tuning.
GB_SUBSAMPLE = 0.8

# Random forests: one classifier and one regressor, each fitted on its dataset.
rf_cls = ntt.Ntree_RandForest_Classifier(n_estimators=N_ESTIMATORS)
rf_cls.fit(Xcls, ycls)

rf_reg = ntt.Ntree_RandForest_Regressor(n_estimators=N_ESTIMATORS)
rf_reg.fit(Xreg, yreg)

# Gradient boosting: same pairing, with row subsampling enabled.
gb_cls = ntt.Ntree_GradBoost_Classifier(
    n_estimators=N_ESTIMATORS,
    subsample=GB_SUBSAMPLE,
)
gb_cls.fit(Xcls, ycls)

gb_reg = ntt.Ntree_GradBoost_Regressor(
    n_estimators=N_ESTIMATORS,
    subsample=GB_SUBSAMPLE,
)
gb_reg.fit(Xreg, yreg)

# 3. Tune ntrees

You can then call the `tune_ntrees` method to get a dictionary that maps each `ntrees` value to its out-of-bag (OOB) error.

In [5]:
# Gradient Boosting: tune_ntrees is called with no arguments on each fitted
# model (regressor first, then classifier) and the returned mapping is printed.
for booster in (gb_reg, gb_cls):
    print(booster.tune_ntrees())


# Random Forests: tune_ntrees additionally receives the data and the
# search settings below.
# NOTE(review): presumably the candidate tree counts run from min_trees to
# max_trees in steps of delta_trees — confirm endpoint handling in ntree_tuning.
min_trees, max_trees, delta_trees = 20, 80, 5

for forest, X, y in ((rf_reg, Xreg, yreg), (rf_cls, Xcls, ycls)):
    print(forest.tune_ntrees(X, y, min_trees, max_trees, delta_trees))

{0: 22532.43275277552, 1: 25565.17055810882, 2: 14520.452742106205, 3: 13331.595424590934, 4: 12441.05454844621, 5: 7735.164050911012, 6: 10254.532774084528, 7: 13086.274272135677, 8: 6539.369841316826, 9: 11141.744694514015, 10: 5963.472062837947, 11: 8223.593864573446, 12: 5094.72413165307, 13: 4291.2229512030735, 14: 4460.3437062526455, 15: 6443.818216777039, 16: 4027.9336694113226, 17: 4129.772231626123, 18: 3693.0223087386485, 19: 4299.950100241993, 20: 3306.310039427347, 21: 2351.979608096933, 22: 3365.912272434302, 23: 2023.1421107198403, 24: 1744.8926498541427, 25: 1985.025567776605, 26: 1505.6828817487315, 27: 1879.2708209800862, 28: 1364.6889380566558, 29: 1450.889375882002, 30: 1246.9524214184573, 31: 815.9757900578761, 32: 1541.6012266515593, 33: 1183.2609431381316, 34: 1382.544895358118, 35: 1048.713261763626, 36: 1447.8861758664773, 37: 1114.5604285498284, 38: 627.3869826607217, 39: 817.3000760170519, 40: 864.7485745590268, 41: 1134.1369782220918, 42: 798.4732000747089, 4

# 4. Predict with ntrees

In [None]:
# Predict on the regression features with ntrees=10 and print the result.
preds_ntrees_10 = gb_reg.predict_ntrees(Xreg, ntrees=10)
print(preds_ntrees_10)