Goal: classify the Iris dataset with the k-nearest neighbors (KNN) algorithm

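Workflow:
1. Load the Iris dataset
2. Split it into training and test sets
3. Standardize the features
4. Train the model and predict
5. Select and tune the model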
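
Before using the scikit-learn estimator, it may help to see the idea behind KNN in isolation. The following is a minimal pure-Python sketch, not the library's implementation: `knn_predict` is a hypothetical helper, and the toy points and labels are made up for illustration. It assumes Euclidean distance and a simple majority vote.

```python
import math
from collections import Counter


def knn_predict(train_points, train_labels, query, k=3):
    """Predict the label of `query` by majority vote among its k nearest training points."""
    # Euclidean distance from the query to every training point
    distances = [
        (math.dist(point, query), label)
        for point, label in zip(train_points, train_labels)
    ]
    # Keep the k closest neighbors
    nearest = sorted(distances, key=lambda pair: pair[0])[:k]
    # Majority vote over the labels of those neighbors
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]


# Toy data (made up for illustration): two well-separated clusters
points = [(1.0, 1.0), (1.2, 0.8), (0.9, 1.1), (5.0, 5.0), (5.2, 4.9), (4.8, 5.1)]
labels = ["a", "a", "a", "b", "b", "b"]

print(knn_predict(points, labels, (1.1, 1.0)))  # a
print(knn_predict(points, labels, (5.1, 5.0)))  # b
```

Because the prediction depends directly on distances, features measured on larger scales would dominate the vote, which is why the workflow includes a standardization step.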

In [64]:
import pandas as pd
import numpy as np

from sklearn.datasets import load_iris

Load the dataset

In [65]:
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
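
The DataFrame above holds only the four feature columns. As a sketch of how one might inspect the data before modeling, the class names can be attached as an extra column; the column name `species` is an assumption chosen for illustration and does not appear in the original notebook.

```python
import pandas as pd
from sklearn.datasets import load_iris

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)

# Map the integer targets (0, 1, 2) to their class names
df["species"] = iris.target_names[iris.target]

print(df.shape)                      # (150, 5)
print(df["species"].value_counts())  # 50 samples per class
print(df.head())
```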

Split into training and test sets

In [66]:
from sklearn.model_selection import train_test_split

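# With the default test_size of 0.25, the 150 samples split into 112 for training and 38 for testing
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=6)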
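
The call above does not stratify, so the class proportions in the test set depend on the random shuffle. A possible variant, shown here only as a sketch, passes `stratify=iris.target` so that each class is represented almost equally in both parts. Note that this changes which samples land in each set, so the scores reported in this notebook would not be reproduced exactly.

```python
from collections import Counter

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()

# Same split size as the notebook, but with class proportions preserved
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=6, stratify=iris.target
)

print(x_train.shape, x_test.shape)  # (112, 4) (38, 4)
print(Counter(y_test.tolist()))     # per-class counts in the test set
```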

Standardize the data

In [67]:
from sklearn.preprocessing import StandardScaler

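transfer = StandardScaler()
# Learn the mean and standard deviation from the training set only
x_train = transfer.fit_transform(x_train)
# Apply the training-set statistics to the test set; do not refit here
x_test = transfer.transform(x_test)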
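
To make the difference between `fit_transform` and `transform` concrete, here is a small sketch on made-up matrices (the numbers are illustrative, not from the Iris data). It checks that the training columns end up with mean 0 and standard deviation 1, while the test row is scaled with the training statistics.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy matrices (made up for illustration)
train = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
test = np.array([[2.0, 40.0]])

scaler = StandardScaler()
train_scaled = scaler.fit_transform(train)  # learns mean and std from train
test_scaled = scaler.transform(test)        # reuses the training statistics

print(scaler.mean_)               # column means learned from train: 2 and 20
print(train_scaled.mean(axis=0))  # approximately 0 for each column
print(train_scaled.std(axis=0))   # approximately 1 for each column
print(test_scaled)                # test row expressed in training-set units
```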

Model training and prediction: create the KNN estimator

In [68]:
from sklearn.neighbors import KNeighborsClassifier

estimator = KNeighborsClassifier()

Grid search with cross-validation

In [69]:
from sklearn.model_selection import GridSearchCV

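# Candidate values of k (n_neighbors) to evaluate
param_dict = {"n_neighbors": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]}
# Wrap the estimator in a grid search that scores each candidate with 10-fold cross-validation
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)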
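
In this notebook the scaler is fitted on the whole training set before the grid search runs, so the scaling statistics already include each validation fold. A commonly used alternative is to put the scaler and the classifier in a `Pipeline`, so the scaler is refitted inside every fold. The sketch below is one way to do that, not the notebook's method; the step names `scaler` and `knn` are arbitrary labels chosen for illustration, and the resulting scores may differ slightly from those reported here.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

iris = load_iris()
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=6
)

# Scaling and classification are bundled, so each CV fold fits its own scaler
pipe = Pipeline([("scaler", StandardScaler()), ("knn", KNeighborsClassifier())])

# Parameters of a pipeline step are addressed as <step name>__<parameter>
param_grid = {"knn__n_neighbors": list(range(3, 15))}

search = GridSearchCV(pipe, param_grid=param_grid, cv=10)
search.fit(x_train, y_train)  # raw features go in; the pipeline scales them

print(search.best_params_)
print(search.score(x_test, y_test))
```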

Fit the grid search, then evaluate on the test set

In [70]:
estimator.fit(x_train, y_train)

score = estimator.score(x_test, y_test)
print("Accuracy:\n", score)

Accuracy:
 0.9473684210526315
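
The cell above reports only the overall accuracy. As a sketch of the prediction step named in the workflow, the fitted search object can also return the predicted labels, which can then be compared with the true labels element by element. The block rebuilds the earlier steps so that it runs on its own; the variable names `y_predict` and `matches` are illustrative.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

iris = load_iris()
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=6
)

transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

estimator = GridSearchCV(
    KNeighborsClassifier(), param_grid={"n_neighbors": list(range(3, 15))}, cv=10
)
estimator.fit(x_train, y_train)

# predict() delegates to the best estimator found by the search
y_predict = estimator.predict(x_test)
matches = y_predict == y_test

print("Predicted labels:\n", y_predict)
print("Matches true labels:\n", matches)
# The fraction of matches equals the accuracy returned by score()
print("Accuracy from comparison:\n", matches.mean())
```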


Inspect the best parameters, the best cross-validation score, and the best estimator

In [71]:
print("Best parameters:\n", estimator.best_params_)
print("Best score:\n", estimator.best_score_)
print("Best estimator:\n", estimator.best_estimator_)
print("Cross-validation results:\n", estimator.cv_results_)

Best parameters:
 {'n_neighbors': 11}
Best score:
 0.9734848484848484
Best estimator:
 KNeighborsClassifier(n_neighbors=11)
Cross-validation results:
 {'mean_fit_time': array([0.00052993, 0.00052366, 0.00044312, 0.00044446, 0.00049322,
       0.00047696, 0.00046227, 0.00044487, 0.00043921, 0.00045037,
       0.0004472 , 0.00045278]), 'std_fit_time': array([1.56272168e-04, 1.54295884e-04, 2.56828938e-06, 4.55597880e-06,
       2.37819515e-05, 4.31410167e-05, 2.80501156e-05, 7.35044905e-06,
       3.82035551e-06, 2.33133874e-05, 1.51406310e-05, 3.02516350e-05]), 'mean_score_time': array([0.00103459, 0.00103326, 0.00091236, 0.00091295, 0.00103979,
       0.00097516, 0.00096185, 0.00092506, 0.00091116, 0.00094414,
       0.00091939, 0.00102096]), 'std_score_time': array([2.65391584e-04, 1.81379911e-04, 8.55457690e-06, 1.01319762e-05,
       6.35883224e-05, 1.03138480e-04, 8.94723049e-05, 4.19078558e-05,
       1.03705505e-05, 6.34500780e-05, 1.68144439e-05, 2.47672952e-04]), 'param_n_neighbors': masked_array(data=[3, 4, 5, 6, 7, 8, 