In [19]:
import numpy as np
import os
import matplotlib
import matplotlib.pyplot as plt

# 一些基础参数配置
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
import warnings
warnings.filterwarnings("ignore")
np.random.seed(42)

In [20]:
from sklearn.datasets import fetch_openml

# https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_openml.html
mnist = fetch_openml('mnist_784', version=1) # 通过 name/dataset-id 从OpenML官网下载数据集
mnist["data"], mnist["target"]

(       pixel1  pixel2  pixel3  pixel4  pixel5  pixel6  pixel7  pixel8  pixel9  \
 0           0       0       0       0       0       0       0       0       0   
 1           0       0       0       0       0       0       0       0       0   
 2           0       0       0       0       0       0       0       0       0   
 3           0       0       0       0       0       0       0       0       0   
 4           0       0       0       0       0       0       0       0       0   
 ...       ...     ...     ...     ...     ...     ...     ...     ...     ...   
 69995       0       0       0       0       0       0       0       0       0   
 69996       0       0       0       0       0       0       0       0       0   
 69997       0       0       0       0       0       0       0       0       0   
 69998       0       0       0       0       0       0       0       0       0   
 69999       0       0       0       0       0       0       0       0       0   
 
        pixel1

In [21]:
# X：一个包含 784 个特征的矩阵
# y：一个标签向量，表示每个图像的数字标签（0 到 9）。
X, y = mnist["data"], mnist["target"]
X.shape, y.shape

((70000, 784), (70000,))

In [22]:
# 取前60000个为训练集，后60000~70000为测试集
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

In [23]:
# 洗牌操作
import numpy as np
shuffle_index = np.random.permutation(len(X_train))
X_train, y_train = X_train.iloc[shuffle_index], y_train.iloc[shuffle_index]

In [24]:
shuffle_index

array([12628, 37730, 39991, ...,   860, 15795, 56422])

In [25]:
y_train_5 = (y_train == "5")
y_test_5 = (y_test == "5")

In [26]:
y_train_5[:10]

12628    False
37730    False
39991    False
8525     False
8279     False
51012    False
14871    False
15127    False
9366      True
33322    False
Name: class, dtype: bool

In [27]:
from sklearn.linear_model import SGDClassifier
sdg_clf = SGDClassifier(max_iter=5, random_state=42) # 创建一个随机梯度下降算法 训练的分类器
sdg_clf.fit(X_train, y_train_5) # 拟合数据

In [28]:
sdg_clf.predict([X.loc[35000]]) # 根据标签定位 X 的元素

array([False])

In [29]:
y[35000]

'1'

In [30]:
# https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_score.html
from sklearn.model_selection import cross_val_score
"""
sdg_clf：它是一个分类器，用于进行训练和预测。

cv=3：指定交叉验证的折数（folds）

scoring='accuracy'：指定评估指标，这里使用的是 accuracy，即准确率，来衡量模型的预测性能。
"""
cross_val_score(sdg_clf, X_train, y_train_5, cv=3, scoring='accuracy')

array([0.964 , 0.9579, 0.9571])

In [31]:
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone

"""
较为复杂，灵活性高
"""
skfolds = StratifiedKFold(n_splits=3, random_state=42,shuffle=True)

for train_index, test_index in skfolds.split(X_train,y_train_5):
    clone_clf = clone(sdg_clf)
    X_train_folds = X_train.iloc[train_index]
    y_train_folds = y_train_5.iloc[train_index]
    X_test_folds = X_train.iloc[test_index]
    y_test_folds = y_train_5.iloc[test_index]
    
    clone_clf.fit(X_train_folds, y_train_folds)
    y_pred = clone_clf.predict(X_test_folds)
    n_correct = sum(y_pred == y_test_folds)
    
    print(n_correct / len(y_pred))

0.963
0.9455
0.95255


## 混淆矩阵

In [32]:
from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict(sdg_clf,X_train,y_train,cv=3)

In [33]:
y_train_pred.shape

(60000,)

In [34]:
X_train.shape

(60000, 784)

In [35]:
from sklearn.metrics import confusion_matrix
y_train_pred = [int(label) for label in y_train_pred]
confusion_matrix(y_train_5,y_train_pred)

array([[6017, 6789, 5483, 6398, 6854,  628, 5551, 7151, 5262, 4446],
       [  90,   36,   44,  361,  292, 3862,  108,   65,  464,   99],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0]],
      dtype=int64)