# 网格搜索

In [17]:
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

In [28]:
para_grid = [
    {
    "weights":["uniform"],
    "n_neighbors":[k for k in range(1,11)]
    },
    {
        "weights":["distance"],
        "n_neighbors":[k for k in range(1,11)],
        "p":[i for i in range(1,5)]
    }
]

### 1.设置超参数列表

In [51]:
para_grid

[{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'weights': ['uniform']},
 {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
  'p': [1, 2, 3, 4],
  'weights': ['distance']}]

In [52]:
#加载数据
digits_data = load_digits()
data_x = digits_data.data
data_y = digits_data.target
#分割数据
train_x,test_x,train_y,test_y=train_test_split(data_x,data_y,test_size=0.3,random_state=566)
#创建模型
knn_cla = KNeighborsClassifier()

### 2.通过图搜索寻找超参数
#### 可选参数 n_jobs默认为1 将使用n核进行并行处理（多核） 传入-1代表使用所有处理器核心
#### 可选参数 verbose   设置为2 可以即时显示信息

In [61]:
grid_search = GridSearchCV(knn_cla,param_grid=para_grid,n_jobs=1,verbose=2)

### fit最佳参数

In [62]:
%%time
grid_search.fit(train_x,train_y)

Fitting 3 folds for each of 50 candidates, totalling 150 fits
[CV] n_neighbors=1, weights=uniform ..................................
[CV] ................... n_neighbors=1, weights=uniform, total=   0.0s
[CV] n_neighbors=1, weights=uniform ..................................
[CV] ................... n_neighbors=1, weights=uniform, total=   0.0s
[CV] n_neighbors=1, weights=uniform ..................................
[CV] ................... n_neighbors=1, weights=uniform, total=   0.0s
[CV] n_neighbors=2, weights=uniform ..................................


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s


[CV] ................... n_neighbors=2, weights=uniform, total=   0.0s
[CV] n_neighbors=2, weights=uniform ..................................
[CV] ................... n_neighbors=2, weights=uniform, total=   0.0s
[CV] n_neighbors=2, weights=uniform ..................................
[CV] ................... n_neighbors=2, weights=uniform, total=   0.0s
[CV] n_neighbors=3, weights=uniform ..................................
[CV] ................... n_neighbors=3, weights=uniform, total=   0.0s
[CV] n_neighbors=3, weights=uniform ..................................
[CV] ................... n_neighbors=3, weights=uniform, total=   0.0s
[CV] n_neighbors=3, weights=uniform ..................................
[CV] ................... n_neighbors=3, weights=uniform, total=   0.0s
[CV] n_neighbors=4, weights=uniform ..................................
[CV] ................... n_neighbors=4, weights=uniform, total=   0.0s
[CV] n_neighbors=4, weights=uniform ..................................
[CV] .

[CV] ............. n_neighbors=3, p=3, weights=distance, total=   0.6s
[CV] n_neighbors=3, p=3, weights=distance ............................
[CV] ............. n_neighbors=3, p=3, weights=distance, total=   0.6s
[CV] n_neighbors=3, p=4, weights=distance ............................
[CV] ............. n_neighbors=3, p=4, weights=distance, total=   0.5s
[CV] n_neighbors=3, p=4, weights=distance ............................
[CV] ............. n_neighbors=3, p=4, weights=distance, total=   0.5s
[CV] n_neighbors=3, p=4, weights=distance ............................
[CV] ............. n_neighbors=3, p=4, weights=distance, total=   0.6s
[CV] n_neighbors=4, p=1, weights=distance ............................
[CV] ............. n_neighbors=4, p=1, weights=distance, total=   0.0s
[CV] n_neighbors=4, p=1, weights=distance ............................
[CV] ............. n_neighbors=4, p=1, weights=distance, total=   0.0s
[CV] n_neighbors=4, p=1, weights=distance ............................
[CV] .

[CV] ............. n_neighbors=8, p=2, weights=distance, total=   0.0s
[CV] n_neighbors=8, p=3, weights=distance ............................
[CV] ............. n_neighbors=8, p=3, weights=distance, total=   0.5s
[CV] n_neighbors=8, p=3, weights=distance ............................
[CV] ............. n_neighbors=8, p=3, weights=distance, total=   0.5s
[CV] n_neighbors=8, p=3, weights=distance ............................
[CV] ............. n_neighbors=8, p=3, weights=distance, total=   0.5s
[CV] n_neighbors=8, p=4, weights=distance ............................
[CV] ............. n_neighbors=8, p=4, weights=distance, total=   0.5s
[CV] n_neighbors=8, p=4, weights=distance ............................
[CV] ............. n_neighbors=8, p=4, weights=distance, total=   0.5s
[CV] n_neighbors=8, p=4, weights=distance ............................
[CV] ............. n_neighbors=8, p=4, weights=distance, total=   0.5s
[CV] n_neighbors=9, p=1, weights=distance ............................
[CV] .

[Parallel(n_jobs=1)]: Done 150 out of 150 | elapsed:  2.1min finished


GridSearchCV(cv=None, error_score='raise',
       estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_neighbors=5, p=2,
           weights='uniform'),
       fit_params=None, iid=True, n_jobs=1,
       param_grid=[{'weights': ['uniform'], 'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}, {'weights': ['distance'], 'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'p': [1, 2, 3, 4]}],
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=2)

## 3.结果

### 最佳超参数

In [45]:
grid_search.best_params_

{'n_neighbors': 4, 'p': 3, 'weights': 'distance'}

### 最佳score

In [46]:
grid_search.best_score_

0.98249801113762925

### 最佳分类器

In [47]:
knn_clf = grid_search.best_estimator_

In [48]:
knn_clf.get_params

<bound method BaseEstimator.get_params of KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_neighbors=4, p=3,
           weights='distance')>

In [49]:
knn_clf.score(test_x,test_y)

0.98518518518518516