In [None]:
import dautil as dl
from sklearn.grid_search import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
import ch9util
import numpy as np
from IPython.display import HTML
import warnings

# Suppress every FutureWarning from here on so that deprecation notices
# raised by the libraries used below do not clutter the cell outputs.
warnings.filterwarnings(action='ignore', category=FutureWarning)

In [None]:
# Fetch the rain data set, already divided into training and test parts by
# the ch9util helper (presumably X_* are features and y_* labels -- the
# split logic itself is not visible in this notebook).
X_train, X_test, y_train, y_test = ch9util.rain_split()

# Base estimator with a fixed seed so that every forest grown during the
# search -- and therefore the selected parameters and the predictions --
# is reproducible across "Restart & Run All".
clf = RandomForestClassifier(random_state=44)

# Hyperparameter grid: 2 x 2 x 2 x 2 = 16 candidate combinations.
params = {
    'max_depth': [2, 4],
    'min_samples_leaf': [1, 3],
    'criterion': ['gini', 'entropy'],
    'n_estimators': [100, 200]
}

# Exhaustive 5-fold cross-validated search, parallelised over all cores.
# The seeded `clf` is passed as the estimator; previously a fresh, unseeded
# RandomForestClassifier() was used here, which silently discarded
# random_state=44 and made the results vary between runs.
# NOTE(review): GridSearchCV is imported from the legacy `sklearn.grid_search`
# module at the top of the notebook; newer scikit-learn releases expose it
# from `sklearn.model_selection` instead.
rfc = GridSearchCV(estimator=clf,
                   param_grid=params, cv=5, n_jobs=-1)
rfc.fit(X_train, y_train)

# Predict on the held-out part with the best parameter combination found,
# then hand the predictions to ch9util.npy_save under the name 'rfc'
# (presumably written to disk for later reuse -- confirm in ch9util).
preds = rfc.predict(X_test)
ch9util.npy_save('rfc', preds)

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# dautil notebook context labelled 'random_forest'; the same object is passed
# to the Subplotter in the plotting cell further down.
context = dl.nb.Context('random_forest')
# Last expression of the cell, so the widget object is displayed as output.
# NOTE(review): presumably an interactive editor for matplotlib rc settings
# tied to this context -- confirm against the dautil documentation.
dl.nb.RcWidget(context)

In [None]:
# Subplotter created with (2, 2) -- presumably a 2x2 grid of axes: the first
# axes receives the rain report, the remaining three the diagnostic plots.
sp = dl.plotting.Subplotter(2, 2, context)
best_model = rfc.best_estimator_
html = ch9util.report_rain(preds, y_test, rfc.best_params_, sp.ax)

# Parameter ranges for the two validation plots:
# tree counts 1, 2, 4, ..., 256 and depths 2 through 8.
ntrees = 2 ** np.arange(9)
depths = np.arange(2, 9)

# One validation plot per hyperparameter, each drawn on the next free axes.
for param_name, param_range in (('n_estimators', ntrees),
                                ('max_depth', depths)):
    ch9util.plot_validation(sp.next_ax(), best_model,
                            X_train, y_train, param_name, param_range)

# Final axes: learning-curve plot for the selected model.
ch9util.plot_learn_curve(sp.next_ax(), best_model, X_train, y_train)

# Show the report HTML followed by whatever sp.exit() returns; this must stay
# the last expression of the cell so that it is rendered as the cell output.
HTML(html + sp.exit())