## [Assignment Focus]
Learn how to use hyper-parameter search in scikit-learn (Sklearn) to find the best hyperparameters.

### Assignment
Use a different dataset and apply hyper-parameter search to see whether you can find the best combination of hyperparameters.
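The notebook below runs the search on the Boston housing data. A minimal sketch of the same workflow on a different dataset is shown here, assuming scikit-learn's built-in `load_diabetes` as the alternative (442 samples, 10 features), a deliberately small grid so the search finishes quickly, and `random_state=0` for reproducibility. These choices are illustrative, not part of the original assignment.

```python
from sklearn.datasets import load_diabetes
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV, train_test_split

# A different dataset than Boston: 442 samples, 10 features
X, y = load_diabetes(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Small grid (assumption) so the search runs fast: 3 * 3 = 9 candidates
param_grid = {"n_estimators": [50, 100, 200], "max_depth": [1, 2, 3]}
grid_search = GridSearchCV(
    GradientBoostingRegressor(random_state=0),
    param_grid,
    scoring="neg_mean_squared_error",
    cv=3,
)
grid_search.fit(x_train, y_train)

# best_estimator_ has already been refit on all of x_train with the best params
test_mse = mean_squared_error(y_test, grid_search.best_estimator_.predict(x_test))
print("best params:", grid_search.best_params_)
print("test mse:", test_mse)
```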

In [6]:
from sklearn import datasets
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import GradientBoostingRegressor

In [2]:
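# Note: load_boston was deprecated in scikit-learn 1.0 and removed in 1.2, so this cell needs scikit-learn < 1.2
boston = datasets.load_boston()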

In [3]:
boston.data.shape

(506, 13)

In [4]:
boston.target.shape

(506,)

In [5]:
# Split the data; an int test_size holds out exactly 50 samples for testing
x_train, x_test, y_train, y_test = train_test_split(boston.data, boston.target, test_size=50, random_state=0)
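Because `test_size=50` is an integer, `train_test_split` treats it as an absolute number of test samples rather than a fraction. A small sketch on dummy arrays with the same shape as the Boston data (zeros are used as placeholder values) shows the difference between an int and a float `test_size`:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholder arrays with the Boston shape: 506 rows, 13 features
X = np.zeros((506, 13))
y = np.zeros(506)

# int test_size: exactly that many samples go to the test set
x_tr, x_te, y_tr, y_te = train_test_split(X, y, test_size=50, random_state=0)
print(x_tr.shape, x_te.shape)  # (456, 13) (50, 13)

# float test_size: a fraction of the samples (the test count is rounded up)
x_tr2, x_te2, _, _ = train_test_split(X, y, test_size=0.1, random_state=0)
print(x_tr2.shape, x_te2.shape)  # (455, 13) (51, 13)
```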

In [7]:
# Baseline model with default hyperparameters, for comparison before tuning
reg = GradientBoostingRegressor()

reg.fit(x_train, y_train)
y_pred = reg.predict(x_test)

print('baseline mse: ', mean_squared_error(y_test, y_pred))

baseline mse:  21.81634081547697


In [13]:
# Define the hyperparameter grid to search
n_estimators = [50, 100, 150, 200, 250, 300]
max_depth = [1, 2, 3, 4, 5, 6]
param_grid = dict(n_estimators=n_estimators, max_depth=max_depth)

# Create the search object from the model and the parameter grid (n_jobs=-1 runs in parallel on all CPU cores)
grid_search = GridSearchCV(reg, param_grid, scoring="neg_mean_squared_error", n_jobs=-1, verbose=1)

# Run the search for the best hyperparameters
grid_result = grid_search.fit(x_train, y_train)

Fitting 3 folds for each of 36 candidates, totalling 108 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done 108 out of 108 | elapsed:    5.4s finished
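The log line "36 candidates, totalling 108 fits" follows from the grid: 6 values of `n_estimators` times 6 values of `max_depth` gives 36 combinations, and each one is fit once per cross-validation fold. The 3 folds reflect the old default `cv` in scikit-learn versions before 0.22; from 0.22 on the default is 5-fold, so the same cell would report 180 fits unless `cv` is passed explicitly (for example `cv=3`). A small sketch using `ParameterGrid` to enumerate the candidates:

```python
from sklearn.model_selection import ParameterGrid

# Same grid as the notebook cell above
param_grid = dict(n_estimators=[50, 100, 150, 200, 250, 300],
                  max_depth=[1, 2, 3, 4, 5, 6])

# GridSearchCV tries every combination: the Cartesian product of the value lists
candidates = list(ParameterGrid(param_grid))
n_candidates = len(candidates)

# Each candidate is fit once per cross-validation fold
fits_3_fold = n_candidates * 3   # old default cv (scikit-learn < 0.22)
fits_5_fold = n_candidates * 5   # default cv since scikit-learn 0.22
print(n_candidates, fits_3_fold, fits_5_fold)  # 36 108 180
```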


In [14]:
print('best result: {} with params= {}'.format(grid_result.best_score_, grid_result.best_params_))

best result: -10.973220860167562 with params= {'max_depth': 3, 'n_estimators': 250}
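scikit-learn scorers follow a "greater is better" convention, so `neg_mean_squared_error` returns the MSE with its sign flipped. A `best_score_` of -10.97 therefore means a mean cross-validated MSE of about 10.97 on the training folds; it is not measured on the 50 held-out test samples, so it is not directly comparable with a test-set MSE. The sketch below, assuming synthetic data from `make_regression` and a small grid in place of the Boston setup, flips the sign back and lists every candidate by rank from `cv_results_`:

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in data (assumption: not the Boston set)
X, y = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=0)
grid_result = GridSearchCV(
    GradientBoostingRegressor(random_state=0),
    {"n_estimators": [50, 100], "max_depth": [2, 3]},
    scoring="neg_mean_squared_error",
    cv=3,
).fit(X, y)

# Flip the sign to get back a positive mean cross-validated MSE
best_cv_mse = -grid_result.best_score_
print("best CV mse:", best_cv_mse)

# List all candidates from best (rank 1) to worst, with their mean CV MSE
results = grid_result.cv_results_
for i in np.argsort(results["rank_test_score"]):
    print(results["params"][i], -results["mean_test_score"][i])
```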


In [15]:
# Build a new model from the best combination found by the search
reg_best_params = GradientBoostingRegressor(**grid_result.best_params_)

reg_best_params.fit(x_train, y_train)
y_pred_best_params = reg_best_params.predict(x_test)

print('mse with best params: ', mean_squared_error(y_test, y_pred_best_params))

mse with best params:  19.740132452164108
