In [1]:
import numpy as np
import os

from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV

import joblib

In [2]:
DATA_PATH = '../data'
IMAGE_SIZE = 784  # 28 * 28

In [3]:
def load_mnist(path, kind='train'):
    """
    'train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz'
    使用前需要把上面四个文件下载到 `path` 目录下并解压
    """
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)

    with open(labels_path, 'rb') as label_file:
        labels = np.frombuffer(label_file.read(), dtype=np.uint8, offset=8)

    with open(images_path, 'rb') as image_file:
        images = np.frombuffer(image_file.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), IMAGE_SIZE)

    return images, labels

In [4]:
# 读取 MNIST 数据集
x_train, y_train = load_mnist(DATA_PATH, kind='train')
# x_test, y_test = load_mnist(DATA_PATH, kind='t10k')

In [5]:
# 特征工程：标准化
transfer1 = StandardScaler()
x_train = transfer1.fit_transform(x_train)

In [6]:
# x_train = x_train[:1000]
# y_train = y_train[:1000]

In [7]:
# SVM 分类器
svm_model1 = SVC(probability=True, max_iter=1000)

In [8]:
# 网格搜索与交叉验证
param_dict = {
    'C': [0.1, 1, 10],
    'kernel': ['linear', 'rbf', 'poly'],
    'gamma': ['scale', 'auto']
}
svm_model1 = GridSearchCV(svm_model1, param_dict, n_jobs=-1, cv=2)

In [9]:
# 训练模型
svm_model1.fit(x_train, y_train)



In [10]:
# 在训练数据上超参数调优的结果
print("最佳参数: \n", svm_model1.best_params_)
print("最佳结果（在验证集中的结果）: \n", svm_model1.best_score_)
print("最佳估计器: \n", svm_model1.best_estimator_)
print("交叉验证结果: \n", svm_model1.cv_results_)

最佳参数: 
 {'C': 10, 'gamma': 'scale', 'kernel': 'poly'}
最佳结果（在验证集中的结果）: 
 0.9689666666666666
最佳估计器: 
 SVC(C=10, kernel='poly', max_iter=1000, probability=True)
交叉验证结果: 
 {'mean_fit_time': array([ 1397.87119842,  6699.38095236, 11146.85924792,  1399.16809976,
        6461.12271762, 10606.0184058 ,  1337.83427143,  3630.13734603,
        8024.03268015,  1799.5630163 ,  3899.3222276 ,  8144.07500541,
        1731.6649164 ,  3805.22069693,  4452.2037009 ,  1661.63770986,
        3139.4716953 ,  3792.1876626 ]), 'std_fit_time': array([24.6229949 ,  1.86749935,  0.96375847, 51.00827062,  6.08587241,
        5.6181339 , 20.27499557, 20.02500916, 57.14050519, 21.01801646,
       18.20079052,  7.24398792, 16.87793124, 20.09451246,  8.58549678,
       21.70051408, 19.19152749, 11.68748569]), 'mean_score_time': array([ 460.87713873, 1952.91423678, 1131.55250514,  460.35963678,
       1884.90881681, 1011.31598747,  468.71635067,  894.51209092,
        867.19800854,  606.85464501,  844.11699224,  382

In [11]:
# 保存模型
joblib.dump(svm_model1, '../models/svm_model1.pkl')
# 保存StandardScaler
joblib.dump(transfer1, '../models/transfer1.pkl')

['transfer2.pkl']