# 作業
請使用不同的資料集，並使用 hyper-parameter search 的方式，看能不能找出最佳的超參數組合

In [12]:
from sklearn import datasets, metrics
digits = datasets.load_digits()
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, stratify=digits.target, test_size=0.25, random_state=4)

clf = GradientBoostingClassifier()

clf.fit(x_train, y_train)

y_pred = clf.predict(x_test)

acc = metrics.accuracy_score(y_test, y_pred)
print("Acuuracy: ", acc)


print("mse :"+str(metrics.mean_squared_error(y_test, y_pred)))

Acuuracy:  0.9688888888888889
mse :0.5511111111111111


In [23]:
from sklearn.model_selection import train_test_split, KFold, GridSearchCV
from sklearn.ensemble import GradientBoostingRegressor

# 設定要訓練的超參數組合
n_estimators = [50, 75, 100, 125, 150,175,200]
max_depth = [1,  3, 5,10,15]
param_grid = dict(n_estimators=n_estimators, max_depth=max_depth)

## 建立搜尋物件，放入模型及參數組合字典 (n_jobs=-1 會使用全部 cpu 平行運算)
grid_search = GridSearchCV(clf, param_grid, scoring="neg_mean_squared_error", n_jobs=-1, verbose=1)

# 開始搜尋最佳參數
grid_result = grid_search.fit(x_train, y_train)

# 預設會跑 3-fold cross-validadtion，總共 9 種參數組合，總共要 train 27 次模型

Fitting 3 folds for each of 35 candidates, totalling 105 fits


[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:   13.6s
[Parallel(n_jobs=-1)]: Done 105 out of 105 | elapsed:   42.3s finished


In [24]:
# 印出最佳結果與最佳參數
print("Best Accuracy: %f using %s" % (grid_result.best_score_, grid_result.best_params_))

Best Accuracy: -0.738679 using {'max_depth': 3, 'n_estimators': 125}


In [25]:
# 使用最佳參數重新建立模型
clf_bestparam = GradientBoostingClassifier(max_depth=grid_result.best_params_['max_depth'],
                                           n_estimators=grid_result.best_params_['n_estimators'])

# 訓練模型
clf_bestparam.fit(x_train, y_train)

# 預測測試集
y_pred = clf_bestparam.predict(x_test)

In [26]:
print(metrics.mean_squared_error(y_test, y_pred))
acc = metrics.accuracy_score(y_test, y_pred)
print("Acuuracy: ", acc)

0.6222222222222222
Acuuracy:  0.9688888888888889
