In [1]:
import scipy.io as sio
import numpy as np
import time

# Read the MATLAB file and unpack the record stored under the 'mnist' key.
# NOTE(review): presumably a 1x1 MATLAB struct whose fields are, in order,
# train images, test images, train labels, test labels -- confirm against the file.
mnistData = sio.loadmat('mnistData.mat')
mnist_record = mnistData['mnist'][0][0]

# Images become one 784-element uint8 column per sample.
# NOTE(review): this reshape assumes the raw arrays already index samples along
# their last axis (e.g. 28 x 28 x N) -- confirm, otherwise pixels get scrambled.
n_pixels = 28*28
training_images = mnist_record[0].reshape((n_pixels, 60000)).astype(np.uint8)
testing_images = mnist_record[1].reshape((n_pixels, 10000)).astype(np.uint8)

# Labels are collapsed to flat 1-D vectors.
training_labels = mnist_record[2].flatten()
testing_labels = mnist_record[3].flatten()

In [2]:

def compute_all_distances(training_images, test_image):
    """Squared L2 distance from one test image to every training image.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, so the cost is
    one matrix-vector product rather than an explicit difference matrix.

    Parameters
    ----------
    training_images : array-like, shape (d, n)
        One flattened training image per column.
    test_image : array-like, shape (d,)
        The query image, flattened the same way as the training columns.

    Returns
    -------
    numpy.ndarray, shape (n,), dtype int64
        Squared distances (no square root is taken; the ordering is the same).

    Notes
    -----
    Inputs are converted to integers, so fractional values are truncated.
    """
    # Promote to int64 once, before any arithmetic. With int32 the squares wrap
    # for values above 46340, and the dot product / `2 * ab` can exceed 2**31
    # for large uint8 images; NumPy wraps silently in both cases.
    train = np.asarray(training_images).astype(np.int64)
    query = np.asarray(test_image).astype(np.int64)

    # Sum of squares over the pixel axis (axis 0) without materialising a
    # full-size squared copy of the training matrix.
    a2 = np.einsum('i...,i...->...', train, train)
    b2 = np.sum(query * query)
    ab = train.T @ query
    return a2 - 2 * ab + b2

In [3]:
def nearest_neighbor_labels(k, training_images, training_labels, test_image):
    """Return the labels of the k training images closest to test_image.

    Closeness is whatever compute_all_distances reports. The returned labels
    are ordered from the nearest neighbour to the k-th nearest.
    """
    dist_to_train = compute_all_distances(training_images, test_image)
    # Rank every training column by ascending distance, then keep the first k.
    ranking = np.argsort(dist_to_train)
    closest_k = ranking[:k]
    return training_labels[closest_k]

In [4]:
import time
# Evaluation: classify every test column with 1-NN and 3-NN and count the hits.
knn_one_correct = 0
knn_three_correct = 0
total = testing_images.shape[1]

start_time = time.time()
for i in range(total):
    if i % 100 == 0:
        print(f"Processing image {i}/{total}...")

    # Labels of the three closest training images, nearest first.
    nearest = nearest_neighbor_labels(3, training_images, training_labels, testing_images[:, i])

    knn_one = nearest[0]

    # Majority vote. When the neighbours disagree so that several labels share
    # the top count (for k=3: all three differ), a plain bincount().argmax()
    # would return the numerically smallest digit. Instead take the closest
    # neighbour whose label reached the top count: argmax returns the first
    # maximum, and `nearest` is ordered by distance.
    votes = np.bincount(nearest)
    knn_three = nearest[np.argmax(votes[nearest])]

    if testing_labels[i] == knn_one:
        knn_one_correct += 1
    if testing_labels[i] == knn_three:
        knn_three_correct += 1

# Results
print("Time taken:", time.time() - start_time)
print("Misclassification rate (k=1):", 1 - knn_one_correct / total)
print("Misclassification rate (k=3):", 1 - knn_three_correct / total)

Processing image 0/10000...
Processing image 100/10000...
Processing image 200/10000...
Processing image 300/10000...
Processing image 400/10000...
Processing image 500/10000...
Processing image 600/10000...
Processing image 700/10000...
Processing image 800/10000...
Processing image 900/10000...
Processing image 1000/10000...
Processing image 1100/10000...
Processing image 1200/10000...
Processing image 1300/10000...
Processing image 1400/10000...
Processing image 1500/10000...
Processing image 1600/10000...
Processing image 1700/10000...
Processing image 1800/10000...
Processing image 1900/10000...
Processing image 2000/10000...
Processing image 2100/10000...
Processing image 2200/10000...
Processing image 2300/10000...
Processing image 2400/10000...
Processing image 2500/10000...
Processing image 2600/10000...
Processing image 2700/10000...
Processing image 2800/10000...
Processing image 2900/10000...
Processing image 3000/10000...
Processing image 3100/10000...
Processing image 320