# 1、加载数据集

In [34]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

In [35]:
iris = load_iris()
x = iris.data
y = iris.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.7, random_state=233, stratify=y)

# 2、超参数

In [36]:
from sklearn.neighbors import KNeighborsClassifier

In [37]:
neigh = KNeighborsClassifier(
    n_neighbors=3, # 最近邻个数
    weights='distance', # 权重
    p=2 # 距离度量
)

neigh.fit(x_train, y_train)
print(neigh.score(x_test, y_test))

0.9428571428571428


- 参数择优

In [38]:
best_score = -1
best_n = 1
best_weight = ''
best_p = -1

for n in range(1, 20):
    for weight in ['uniform', 'distance']:
        for p in range(1, 7):
            neigh = KNeighborsClassifier(
                n_neighbors=n,
                weights=weight,
                p=p
            )
            neigh.fit(x_train, y_train)
            score = neigh.score(x_test, y_test)
            if score > best_score:
                best_score = score
                best_n = n
                best_weight = weight
                best_p = p

print(best_score, '/', best_n, '/', best_weight, '/', best_p)

0.9619047619047619 / 1 / uniform / 1


- sklearn搜索

In [39]:
from sklearn.model_selection import GridSearchCV

In [None]:
params = {
    'n_neighbors': np.arange(1, 20),
    'weights': ['uniform', 'distance'],
    'p': np.arange(1, 7)
}

grid = GridSearchCV(
    estimator=KNeighborsClassifier(), 
    param_grid=params, 
    n_jobs=-1
)

grid.fit(x_train, y_train)
print(grid.best_score_)
print('-'*100)
print(grid.best_params_)
print('-'*100)
print(grid.best_estimator_.predict(x_test))
print('-'*100)
print(grid.best_estimator_.score(x_test, y_test))

1.0
----------------------------------------------------------------------------------------------------
{'n_neighbors': 1, 'p': 2, 'weights': 'uniform'}
----------------------------------------------------------------------------------------------------
[0 2 0 0 2 1 1 0 0 1 2 1 0 0 1 1 0 1 1 1 1 0 0 1 1 1 1 2 0 0 0 1 2 2 1 0 0
 0 0 0 0 1 2 1 2 1 2 1 0 1 0 2 1 0 1 2 1 1 2 0 2 0 0 1 2 2 0 1 2 0 2 2 1 2
 0 2 2 1 1 2 2 1 2 1 0 2 2 2 1 1 1 2 1 0 0 0 1 0 2 0 1 2 2 0 2]
----------------------------------------------------------------------------------------------------
0.9523809523809523
