In [77]:
import time
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm

In [78]:
def load_data(preprocess = False):
    pipline = transforms.ToTensor()
    dataset_train = datasets.MNIST(root='Mnist', train=True, transform=pipline, download=True)
    dataset_test = datasets.MNIST(root='Mnist', train=False, transform=pipline, download=True)
    imgs_train = dataset_train.data.numpy()
    imgs_train = np.reshape(imgs_train, (imgs_train.shape[0], -1))  # flatten each image

    imgs_test = dataset_test.data.numpy()
    imgs_test = np.reshape(imgs_test, (imgs_test.shape[0], -1))  # flatten each image
    if preprocess:
        # for img in tqdm(imgs_train, desc='Train Data Preprocess'):
        #     for r in range(784):
        #         img[r, ]=255 if img[r, ]>= 128 else 0
        # for img in tqdm(imgs_test, desc='Test Data Preprocess'):
        #     for r in range(784):
        #         img[r, ]=255 if img[r, ]>= 128 else 0
        imgs_train[imgs_train < 127] = 0
        imgs_train[imgs_train != 0] = 255
        imgs_test[imgs_test < 127] = 0
        imgs_test[imgs_test != 0] = 255
    labl_train = dataset_train.targets.numpy()
    labl_test = dataset_test.targets.numpy()
    return imgs_train, labl_train, imgs_test, labl_test

In [79]:
def get_distance_list(k, test_image, train_images, train_labels):
    res = []
    for i in range(len(train_images)):
        dist = np.linalg.norm(test_image - train_images[i])
        # print(train_labels)
        res.append([dist, train_labels[i]])
    res = sorted(res, key=(lambda t: t[0]))
    return res[:k]

In [80]:
def KNNClassify(k, num_test, preprocess):
    result = []
    cnt = 1
    imgs_train, labl_train, imgs_test, labl_test = load_data(preprocess)
    imgs_test, labl_test = imgs_test[:num_test], labl_test[:num_test]
    # imgs_train, labl_train = imgs_train[:10], labl_train[10]
    # print('imgs_train', imgs_train.shape)
    # print('labl_train', labl_train.shape)
    # print('imgs_test', imgs_test.shape)
    # print('labl_test', labl_test.shape)
    for i in tqdm(range(num_test)):
        # print(f'The {cnt}th image test starting...')
        # cnt += 1
        dist_lst = get_distance_list(k, imgs_test[i], imgs_train, labl_train)
        k_labels = []

        for dist, label in dist_lst:
            k_labels.append(label)
        result.append(max(k_labels, key=k_labels.count))

    print('Prediction->',result)
    return result

In [81]:
def evaluate():
    k = 20
    num_test = 20
    num_correct = 0
    imgs_train, labl_train, imgs_test, labl_test = load_data()
    # print('imgs_train', imgs_train.shape)
    # print('labl_train', labl_train.shape)
    # print('imgs_test', imgs_test.shape)
    # print('labl_test', labl_test.shape)
    start_time = time.time()
    outputlabels = KNNClassify(k, num_test, preprocess=True)
    for idx, pred_label in enumerate(outputlabels):
        num_correct += (pred_label == labl_test[idx])
    accuracy = num_correct / num_test
    print("---classification accuracy for knn on mnist: %s ---" % accuracy)
    print("---execution time: %s seconds ---" % (time.time() - start_time))

In [82]:
evaluate()

100%|██████████| 20/20 [00:05<00:00,  3.89it/s]

Prediction-> [7, 2, 1, 0, 4, 1, 4, 9, 0, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4]
---classification accuracy for knn on mnist: 0.95 ---
---execution time: 5.315220355987549 seconds ---



