In [11]:
from tensorflow.keras.datasets.mnist import load_data
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
import time
from tqdm import tqdm
%matplotlib inline

# Load MNIST (train / test split as provided by Keras) and report how many
# samples each of the four arrays holds.
(X_train, Y_train), (X_test, Y_test) = load_data()
for caption, arr in (("X_train: ", X_train), ("Y_train: ", Y_train),
                     ("X_test : ", X_test), ("Y_test : ", Y_test)):
    print(caption, len(arr), "件")

X_train:  60000 件
Y_train:  60000 件
X_test :  10000 件
Y_test :  10000 件


In [12]:
# Shapes of the arrays straight after loading, before any reshaping.
shape_report = (
    ("X_train.shape: ", X_train),
    ("X_test.shape : ", X_test),
    ("Y_train.shape: ", Y_train),
    ("Y_test.shape : ", Y_test),
)
for caption, arr in shape_report:
    print(caption, arr.shape)

X_train.shape:  (60000, 28, 28)
X_test.shape :  (10000, 28, 28)
Y_train.shape:  (60000,)
Y_test.shape :  (10000,)


In [13]:
# Flatten each 28x28 image into one 784-element row: scikit-learn estimators
# take a 2-D (n_samples, n_features) array. Deriving the row count from len()
# instead of hard-coding 60000 / 10000 keeps the cell valid for any subset,
# and re-running it on already-flattened data is a no-op.
X_train = X_train.reshape(len(X_train), -1)
X_test = X_test.reshape(len(X_test), -1)
print("変換後X_train: ", X_train.shape)
print("変換後X_test : ", X_test.shape)

変換後X_train:  (60000, 784)
変換後X_test :  (10000, 784)


In [14]:
# parameters -------------------------------------
k_list = [3, 5, 7, 9, 11, 13, 15]  # k (n_neighbors) values to compare
test_size = 0.01          # fraction of the 60,000 training images held out for validation
random_state = 60         # random_state for the validation split below
show_misclassified = True  # plot misclassified validation images after every run
# -----------------------------------------------


def _show_misclassified(images, predicted, expected, n_rows=3, n_cols=6):
    """Plot validation images whose predicted label differs from the true label.

    images    : 2-D array whose rows are flattened 28x28 images.
    predicted : labels predicted by the model, one per row of `images`.
    expected  : true labels, one per row of `images`.

    At most n_rows * n_cols images are drawn: a subplot index larger than the
    grid raises ValueError, so surplus errors are skipped instead of crashing.
    """
    wrong_idx = np.flatnonzero(predicted != expected)
    fig = plt.figure(figsize=(14, 10))
    for panel, idx in enumerate(wrong_idx[:n_rows * n_cols], start=1):
        ax = fig.add_subplot(n_rows, n_cols, panel)
        ax.set_axis_off()
        ax.set_title("correct: {}, wrong:{}".format(expected[idx], predicted[idx]))
        ax.imshow(images[idx].reshape(28, 28).astype("uint8"), cmap="gray")
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per run


# Hold out `test_size` of the training images as a fixed validation set; every
# model below is scored on this same set.
trData, valData, trLabels, valLabels = train_test_split(
    X_train, Y_train, test_size=test_size, random_state=random_state)

# Fraction of trData DISCARDED in each run (0.9, 0.8, ..., 0.1), so the
# training subset grows from about 10% to about 90% of trData.
nVals = np.arange(0.9, 0.0, -0.1)

accuracies = []  # accuracies[i][j]: validation accuracy for k_list[i] and nVals[j]
times = []       # times[i][j]: fit + predict time in seconds, same indexing

for k in k_list:
    accuracy = []   # validation accuracy per training-set size for this k
    proc_time = []  # elapsed seconds per training-set size for this k
    for per in tqdm(nVals, desc="k={}".format(k)):
        # Drop the fraction `per` of trData and train on the remainder.
        trainData, _, trainLabels, _ = train_test_split(
            trData, trLabels, test_size=per, random_state=42)

        start = time.perf_counter()  # start timing
        model = KNeighborsClassifier(n_neighbors=k)
        model.fit(trainData, trainLabels)
        # Predict once and derive the accuracy from it; calling score() and then
        # predict() would run the k-NN neighbour search twice.
        val_predict = model.predict(valData)
        duration = time.perf_counter() - start  # stop timing
        score = np.mean(val_predict == valLabels)

        print("train size=%d, accuracy=%.2f%%, time=%.2f[s]" % (len(trainLabels), score * 100, duration))
        accuracy.append(score)
        proc_time.append(duration)

        if show_misclassified:
            _show_misclassified(valData, val_predict, valLabels)

    # Write this k's accuracy and timing series once every size has been run.
    suffix_pct = str((1 - test_size) * 100)
    with open("k=" + str(k) + "_train" + suffix_pct + "%_accuracies.txt", mode='w') as f:
        f.writelines("%s\n" % d for d in accuracy)
    with open("k=" + str(k) + "_train" + suffix_pct + "%_times.txt", mode='w') as f:
        f.writelines("%s\n" % d for d in proc_time)

    accuracies.append(accuracy)
    times.append(proc_time)

 11%|█████████▎                                                                          | 1/9 [00:02<00:23,  2.89s/it]

train size=3000, accuracy=91.83%, time=2.87[s]
train size= 2999 achieved highest accuracy of 91.83% on validation data


 22%|██████████████████▋                                                                 | 2/9 [00:08<00:29,  4.25s/it]

train size=6000, accuracy=93.68%, time=5.19[s]
train size= 5999 achieved highest accuracy of 93.68% on validation data


 33%|████████████████████████████                                                        | 3/9 [00:15<00:34,  5.79s/it]

train size=8999, accuracy=94.71%, time=7.61[s]
train size= 8999 achieved highest accuracy of 94.71% on validation data


 44%|█████████████████████████████████████▎                                              | 4/9 [00:25<00:36,  7.37s/it]

train size=11999, accuracy=95.24%, time=9.77[s]
train size=11999 achieved highest accuracy of 95.24% on validation data


 56%|██████████████████████████████████████████████▋                                     | 5/9 [00:37<00:36,  9.21s/it]

train size=14999, accuracy=95.66%, time=12.46[s]
train size=14999 achieved highest accuracy of 95.66% on validation data


 67%|████████████████████████████████████████████████████████                            | 6/9 [00:54<00:34, 11.61s/it]

train size=17999, accuracy=95.98%, time=16.26[s]
train size=17999 achieved highest accuracy of 95.98% on validation data


 78%|█████████████████████████████████████████████████████████████████▎                  | 7/9 [01:12<00:27, 13.94s/it]

train size=20999, accuracy=96.11%, time=18.70[s]
train size=20999 achieved highest accuracy of 96.11% on validation data


 89%|██████████████████████████████████████████████████████████████████████████▋         | 8/9 [01:33<00:16, 16.16s/it]

train size=23999, accuracy=96.37%, time=20.91[s]
train size=23999 achieved highest accuracy of 96.37% on validation data


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [01:57<00:00, 13.04s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

train size=26999, accuracy=96.51%, time=23.45[s]
train size=26999 achieved highest accuracy of 96.51% on validation data


 11%|█████████▎                                                                          | 1/9 [00:03<00:29,  3.69s/it]

train size=3000, accuracy=91.53%, time=3.67[s]
train size= 2999 achieved highest accuracy of 91.53% on validation data


 22%|██████████████████▋                                                                 | 2/9 [00:10<00:37,  5.36s/it]

train size=6000, accuracy=93.57%, time=6.51[s]
train size= 5999 achieved highest accuracy of 93.57% on validation data


 33%|████████████████████████████                                                        | 3/9 [00:19<00:43,  7.25s/it]

train size=8999, accuracy=94.45%, time=9.48[s]
train size= 8999 achieved highest accuracy of 94.45% on validation data


 44%|█████████████████████████████████████▎                                              | 4/9 [00:31<00:45,  9.02s/it]

train size=11999, accuracy=94.78%, time=11.72[s]
train size=11999 achieved highest accuracy of 94.78% on validation data


 56%|██████████████████████████████████████████████▋                                     | 5/9 [00:47<00:45, 11.42s/it]

train size=14999, accuracy=95.32%, time=15.66[s]
train size=14999 achieved highest accuracy of 95.32% on validation data


 67%|████████████████████████████████████████████████████████                            | 6/9 [01:05<00:41, 13.91s/it]

train size=17999, accuracy=95.67%, time=18.72[s]
train size=17999 achieved highest accuracy of 95.67% on validation data


 78%|█████████████████████████████████████████████████████████████████▎                  | 7/9 [01:27<00:32, 16.42s/it]

train size=20999, accuracy=95.91%, time=21.59[s]
train size=20999 achieved highest accuracy of 95.91% on validation data


 89%|██████████████████████████████████████████████████████████████████████████▋         | 8/9 [01:52<00:19, 19.07s/it]

train size=23999, accuracy=96.16%, time=24.73[s]
train size=23999 achieved highest accuracy of 96.16% on validation data


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [02:21<00:00, 15.78s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

train size=26999, accuracy=96.29%, time=29.76[s]
train size=26999 achieved highest accuracy of 96.29% on validation data


 11%|█████████▎                                                                          | 1/9 [00:03<00:31,  3.96s/it]

train size=3000, accuracy=91.11%, time=3.94[s]
train size= 2999 achieved highest accuracy of 91.11% on validation data


 22%|██████████████████▋                                                                 | 2/9 [00:10<00:40,  5.76s/it]

train size=6000, accuracy=93.26%, time=7.00[s]
train size= 5999 achieved highest accuracy of 93.26% on validation data


 33%|████████████████████████████                                                        | 3/9 [00:22<00:48,  8.16s/it]

train size=8999, accuracy=94.14%, time=11.01[s]
train size= 8999 achieved highest accuracy of 94.14% on validation data


 44%|█████████████████████████████████████▎                                              | 4/9 [00:36<00:53, 10.76s/it]

train size=11999, accuracy=94.71%, time=14.73[s]
train size=11999 achieved highest accuracy of 94.71% on validation data


 56%|██████████████████████████████████████████████▋                                     | 5/9 [00:53<00:51, 12.81s/it]

train size=14999, accuracy=95.08%, time=16.41[s]
train size=14999 achieved highest accuracy of 95.08% on validation data


 67%|████████████████████████████████████████████████████████                            | 6/9 [01:13<00:46, 15.48s/it]

train size=17999, accuracy=95.48%, time=20.66[s]
train size=17999 achieved highest accuracy of 95.48% on validation data


 78%|█████████████████████████████████████████████████████████████████▎                  | 7/9 [01:36<00:35, 17.92s/it]

train size=20999, accuracy=95.71%, time=22.93[s]
train size=20999 achieved highest accuracy of 95.71% on validation data
train size=23999, accuracy=95.94%, time=24.98[s]


 89%|██████████████████████████████████████████████████████████████████████████▋         | 8/9 [02:02<00:20, 20.25s/it]

train size=23999 achieved highest accuracy of 95.94% on validation data


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [02:30<00:00, 16.78s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

train size=26999, accuracy=96.10%, time=28.93[s]
train size=26999 achieved highest accuracy of 96.10% on validation data


 11%|█████████▎                                                                          | 1/9 [00:03<00:31,  3.89s/it]

train size=3000, accuracy=90.73%, time=3.86[s]
train size= 2999 achieved highest accuracy of 90.73% on validation data


 22%|██████████████████▋                                                                 | 2/9 [00:10<00:40,  5.73s/it]

train size=6000, accuracy=92.89%, time=7.00[s]
train size= 5999 achieved highest accuracy of 92.89% on validation data


 33%|████████████████████████████                                                        | 3/9 [00:20<00:45,  7.61s/it]

train size=8999, accuracy=93.87%, time=9.83[s]
train size= 8999 achieved highest accuracy of 93.87% on validation data


 44%|█████████████████████████████████████▎                                              | 4/9 [00:33<00:47,  9.49s/it]

train size=11999, accuracy=94.37%, time=12.35[s]
train size=11999 achieved highest accuracy of 94.37% on validation data


 56%|██████████████████████████████████████████████▋                                     | 5/9 [00:50<00:49, 12.30s/it]

train size=14999, accuracy=94.85%, time=17.26[s]
train size=14999 achieved highest accuracy of 94.85% on validation data


 67%|████████████████████████████████████████████████████████                            | 6/9 [01:12<00:46, 15.64s/it]

train size=17999, accuracy=95.26%, time=22.12[s]
train size=17999 achieved highest accuracy of 95.26% on validation data


 78%|█████████████████████████████████████████████████████████████████▎                  | 7/9 [01:38<00:37, 18.90s/it]

train size=20999, accuracy=95.49%, time=25.61[s]
train size=20999 achieved highest accuracy of 95.49% on validation data


 89%|██████████████████████████████████████████████████████████████████████████▋         | 8/9 [02:05<00:21, 21.65s/it]

train size=23999, accuracy=95.73%, time=27.50[s]
train size=23999 achieved highest accuracy of 95.73% on validation data


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [02:38<00:00, 17.66s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

train size=26999, accuracy=95.91%, time=33.29[s]
train size=26999 achieved highest accuracy of 95.91% on validation data


 11%|█████████▎                                                                          | 1/9 [00:04<00:36,  4.52s/it]

train size=3000, accuracy=90.57%, time=4.50[s]
train size= 2999 achieved highest accuracy of 90.57% on validation data


 22%|██████████████████▋                                                                 | 2/9 [00:12<00:47,  6.77s/it]

train size=6000, accuracy=92.67%, time=8.33[s]
train size= 5999 achieved highest accuracy of 92.67% on validation data


 33%|████████████████████████████                                                        | 3/9 [00:25<00:55,  9.29s/it]

train size=8999, accuracy=93.66%, time=12.27[s]
train size= 8999 achieved highest accuracy of 93.66% on validation data


 44%|█████████████████████████████████████▎                                              | 4/9 [00:41<00:59, 11.98s/it]

train size=11999, accuracy=94.23%, time=16.08[s]
train size=11999 achieved highest accuracy of 94.23% on validation data


 56%|██████████████████████████████████████████████▋                                     | 5/9 [01:02<01:01, 15.28s/it]

train size=14999, accuracy=94.58%, time=21.12[s]
train size=14999 achieved highest accuracy of 94.58% on validation data


 67%|████████████████████████████████████████████████████████                            | 6/9 [01:23<00:52, 17.38s/it]

train size=17999, accuracy=94.99%, time=21.44[s]
train size=17999 achieved highest accuracy of 94.99% on validation data


 78%|█████████████████████████████████████████████████████████████████▎                  | 7/9 [01:49<00:40, 20.00s/it]

train size=20999, accuracy=95.25%, time=25.38[s]
train size=20999 achieved highest accuracy of 95.25% on validation data


 89%|██████████████████████████████████████████████████████████████████████████▋         | 8/9 [02:17<00:22, 22.64s/it]

train size=23999, accuracy=95.58%, time=28.26[s]
train size=23999 achieved highest accuracy of 95.58% on validation data


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [02:49<00:00, 18.84s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

train size=26999, accuracy=95.82%, time=32.02[s]
train size=26999 achieved highest accuracy of 95.82% on validation data


 11%|█████████▎                                                                          | 1/9 [00:04<00:34,  4.30s/it]

train size=3000, accuracy=90.29%, time=4.28[s]
train size= 2999 achieved highest accuracy of 90.29% on validation data


 22%|██████████████████▋                                                                 | 2/9 [00:13<00:48,  6.97s/it]

train size=6000, accuracy=92.46%, time=8.83[s]
train size= 5999 achieved highest accuracy of 92.46% on validation data


 33%|████████████████████████████                                                        | 3/9 [00:24<00:52,  8.76s/it]

train size=8999, accuracy=93.54%, time=10.88[s]
train size= 8999 achieved highest accuracy of 93.54% on validation data


 44%|█████████████████████████████████████▎                                              | 4/9 [00:37<00:53, 10.67s/it]

train size=11999, accuracy=93.99%, time=13.56[s]
train size=11999 achieved highest accuracy of 93.99% on validation data


 56%|██████████████████████████████████████████████▋                                     | 5/9 [00:54<00:51, 12.78s/it]

train size=14999, accuracy=94.38%, time=16.50[s]
train size=14999 achieved highest accuracy of 94.38% on validation data


 67%|████████████████████████████████████████████████████████                            | 6/9 [01:13<00:45, 15.11s/it]

train size=17999, accuracy=94.77%, time=19.62[s]
train size=17999 achieved highest accuracy of 94.77% on validation data


 78%|█████████████████████████████████████████████████████████████████▎                  | 7/9 [01:36<00:35, 17.56s/it]

train size=20999, accuracy=95.05%, time=22.59[s]
train size=20999 achieved highest accuracy of 95.05% on validation data


 89%|██████████████████████████████████████████████████████████████████████████▋         | 8/9 [02:04<00:20, 20.77s/it]

train size=23999, accuracy=95.34%, time=27.62[s]
train size=23999 achieved highest accuracy of 95.34% on validation data


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [02:37<00:00, 17.47s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

train size=26999, accuracy=95.60%, time=33.17[s]
train size=26999 achieved highest accuracy of 95.60% on validation data


 11%|█████████▎                                                                          | 1/9 [00:05<00:40,  5.05s/it]

train size=3000, accuracy=89.93%, time=5.02[s]
train size= 2999 achieved highest accuracy of 89.93% on validation data


 22%|██████████████████▋                                                                 | 2/9 [00:13<00:49,  7.03s/it]

train size=6000, accuracy=92.36%, time=8.40[s]
train size= 5999 achieved highest accuracy of 92.36% on validation data


 33%|████████████████████████████                                                        | 3/9 [00:24<00:53,  8.90s/it]

train size=8999, accuracy=93.31%, time=11.10[s]
train size= 8999 achieved highest accuracy of 93.31% on validation data


 44%|█████████████████████████████████████▎                                              | 4/9 [00:38<00:53, 10.76s/it]

train size=11999, accuracy=93.86%, time=13.59[s]
train size=11999 achieved highest accuracy of 93.86% on validation data


 56%|██████████████████████████████████████████████▋                                     | 5/9 [00:56<00:53, 13.39s/it]

train size=14999, accuracy=94.18%, time=18.05[s]
train size=14999 achieved highest accuracy of 94.18% on validation data


 67%|████████████████████████████████████████████████████████                            | 6/9 [01:18<00:49, 16.56s/it]

train size=17999, accuracy=94.60%, time=22.69[s]
train size=17999 achieved highest accuracy of 94.60% on validation data


 78%|█████████████████████████████████████████████████████████████████▎                  | 7/9 [01:41<00:36, 18.49s/it]

train size=20999, accuracy=94.92%, time=22.43[s]
train size=20999 achieved highest accuracy of 94.92% on validation data


 89%|██████████████████████████████████████████████████████████████████████████▋         | 8/9 [02:06<00:20, 20.72s/it]

train size=23999, accuracy=95.17%, time=25.50[s]
train size=23999 achieved highest accuracy of 95.17% on validation data


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [02:35<00:00, 17.32s/it]

train size=26999, accuracy=95.40%, time=28.93[s]
train size=26999 achieved highest accuracy of 95.40% on validation data





In [None]:
plt.rcParams['font.size'] = 16

# Accuracy (left axis, solid) and processing time (right axis, dashed) versus
# training-set size, one colour per k.
fig, ax1 = plt.subplots(figsize=(12, 10))
ax2 = ax1.twinx()

# x axis: approximate training-set size for each discarded fraction in nVals
# (train_test_split rounds the held-out count, so actual sizes can differ by one).
x = len(trLabels) * (1 - nVals)

# accuracies / times hold one series per k; plot each series against x.
for idx, (k_value, acc_k, time_k) in enumerate(zip(k_list, accuracies, times)):
    color = "C{}".format(idx % 10)
    ax1.plot(x, acc_k, color=color, marker="o", label="accuracy (k={})".format(k_value))
    ax2.plot(x, time_k, color=color, linestyle="--", label="time (k={})".format(k_value))

ax1.set_xlabel('Number of training samples')
ax1.set_ylabel('Accuracies')
ax2.set_ylabel('Processing Time [s]')

# Merge the entries of both axes into a single legend.
handles1, labels1 = ax1.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(handles1 + handles2, labels1 + labels2, fontsize=10, loc="center right")

# Save to file; the name records the range of k values compared.
name = "nn_k={}-{}_compare.png".format(k_list[0], k_list[-1])
fig.savefig(name)

plt.show()