# K-Nearest Neighbour On MNIST Data

The goal of this notebook is to classify handwritten digits from 0 to 9. Instead of the full MNIST dataset, which contains 60,000 training images and 10,000 testing images, we work with the much smaller digits dataset bundled with scikit-learn. It contains 1,797 digit images, which we divide into training, validation, and testing sets.
Each image is an 8 x 8 grayscale image; scikit-learn also exposes each one as a flattened 64-dimensional feature vector, which is the representation the classifier uses.
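
As a quick check of the layout described above, the following sketch loads the digits dataset and compares the image array with the flattened feature matrix. The variable names `digits` and `same` are ours for illustration; the notebook itself stores the dataset in `mnist`.

```python
import numpy as np
from sklearn import datasets

# load the small digits dataset bundled with scikit-learn
digits = datasets.load_digits()

# 8 x 8 images and their flattened 64-dimensional counterparts
print(digits.images.shape)
print(digits.data.shape)

# each row of `data` should be the corresponding image, flattened row by row
same = np.array_equal(digits.images.reshape(len(digits.images), -1), digits.data)
print(same)
```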

In [2]:
# import the necessary packages (unused imports such as tensorflow removed)
from __future__ import print_function
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
from sklearn import datasets
from skimage import exposure
import numpy as np
import cv2
import imutils
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

In [3]:
# load the scikit-learn digits dataset (a small, MNIST-like set of 8 x 8 images)
mnist = datasets.load_digits()

In [4]:
# take the MNIST data and construct the training and testing split, using 75% of the
# data for training and 25% for testing
(trainData, testData, trainLabels, testLabels) = train_test_split(np.array(mnist.data),
	mnist.target, test_size=0.25, random_state=42)

In [5]:
# now, let's take 10% of the training data and use that for validation
(trainData, valData, trainLabels, valLabels) = train_test_split(trainData, trainLabels,
	test_size=0.1, random_state=84)

In [6]:
# show the sizes of each data split
print("training data points: {}".format(len(trainLabels)))
print("validation data points: {}".format(len(valLabels)))
print("testing data points: {}".format(len(testLabels)))

training data points: 1212
validation data points: 135
testing data points: 450
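
The split sizes above can be reproduced by hand. The sketch below assumes that the held-out share is rounded up and the remainder goes to the other split, which is consistent with the numbers printed above; the variable names are ours.

```python
import math

n_total = 1797

# first split: 25% of all samples are held out for testing
n_test = math.ceil(0.25 * n_total)
n_remaining = n_total - n_test

# second split: 10% of the remaining samples are held out for validation
n_val = math.ceil(0.10 * n_remaining)
n_train = n_remaining - n_val

print(n_train, n_val, n_test)
```

Note that the validation set is 10% of the training portion, not of the full dataset, so it ends up at roughly 7.5% of all samples.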


Now that we have our data splits taken care of, let’s train our classifier and find the optimal value of k.


In [7]:
# initialize the values of k for our k-Nearest Neighbor classifier along with the
# list of accuracies for each value of k
kVals = range(1, 30, 2)
accuracies = []
 

In [8]:
# loop over various values of `k` for the k-Nearest Neighbor classifier
for k in kVals:
    # train the k-Nearest Neighbor classifier with the current value of `k`
    model = KNeighborsClassifier(n_neighbors=k)
    model.fit(trainData, trainLabels)

    # evaluate the trained model on the validation data and record its accuracy
    score = model.score(valData, valLabels)
    print("k=%d, accuracy=%.2f%%" % (k, score * 100))
    accuracies.append(score)

k=1, accuracy=99.26%
k=3, accuracy=99.26%
k=5, accuracy=99.26%
k=7, accuracy=99.26%
k=9, accuracy=99.26%
k=11, accuracy=99.26%
k=13, accuracy=99.26%
k=15, accuracy=99.26%
k=17, accuracy=98.52%
k=19, accuracy=98.52%
k=21, accuracy=97.78%
k=23, accuracy=97.04%
k=25, accuracy=97.78%
k=27, accuracy=97.04%
k=29, accuracy=97.04%


In [9]:
# find the value of k that has the largest accuracy
i = int(np.argmax(accuracies))
print("k=%d achieved highest accuracy of %.2f%% on validation data" % (kVals[i],
    accuracies[i] * 100))

k=1 achieved highest accuracy of 99.26% on validation data


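The validation accuracy is identical for \( k = 1 \) through \( k = 15 \). Because `np.argmax` returns the index of the first maximum, the tie is resolved in favor of \( k = 1 \), and that is the value we use to train and evaluate our classifier on the final test data. Keep in mind that with only 135 validation samples, a single misclassification changes the accuracy by about 0.74 percentage points, so the differences between these values of \( k \) amount to one or two images.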
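
When several values of k tie on validation accuracy, the result depends on the tie-breaking rule. The sketch below uses the accuracies printed above (copied by hand, in percent) to compare the first tied value with the largest tied value; the names `k_vals`, `acc`, `first_best`, and `largest_tied` are ours.

```python
import numpy as np

k_vals = list(range(1, 30, 2))

# validation accuracies in percent, copied from the output above
acc = np.array([99.26] * 8 + [98.52, 98.52, 97.78, 97.04, 97.78, 97.04, 97.04])

# np.argmax returns the index of the FIRST maximum
first_best = k_vals[int(np.argmax(acc))]

# alternative rule: collect all tied values and take the largest one
tied = [k for k, a in zip(k_vals, acc) if a == acc.max()]
largest_tied = tied[-1]

print(first_best, tied, largest_tied)
```

A larger k among the tied values is sometimes preferred because the vote averages over more neighbours and tends to be less sensitive to individual noisy samples. Which rule is appropriate is a judgment call, and the notebook keeps the first maximum.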

In [10]:
# re-train our classifier using the best k value and predict the labels of the
# test data
model = KNeighborsClassifier(n_neighbors=kVals[i])
model.fit(trainData, trainLabels)
predictions = model.predict(testData)

In [11]:
# show a final classification report demonstrating the accuracy of the classifier
# for each of the digits
print("EVALUATION ON TESTING DATA")
print(classification_report(testLabels, predictions))

EVALUATION ON TESTING DATA
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        43
           1       0.95      1.00      0.97        37
           2       1.00      1.00      1.00        38
           3       0.98      0.98      0.98        46
           4       0.98      0.98      0.98        55
           5       0.98      1.00      0.99        59
           6       1.00      1.00      1.00        45
           7       1.00      0.98      0.99        41
           8       0.97      0.95      0.96        38
           9       0.96      0.94      0.95        48

    accuracy                           0.98       450
   macro avg       0.98      0.98      0.98       450
weighted avg       0.98      0.98      0.98       450
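
To make the precision and recall columns of the report concrete, here is a small sketch that computes both for one class from raw labels. The toy labels and the helper `precision_recall` are hypothetical and are not taken from the digits data.

```python
import numpy as np

# toy ground-truth and predicted labels (illustrative only)
y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 1, 2, 2, 2, 1])

def precision_recall(y_true, y_pred, label):
    # true positives: predicted `label` and actually `label`
    tp = np.sum((y_pred == label) & (y_true == label))
    predicted = np.sum(y_pred == label)  # everything predicted as `label`
    actual = np.sum(y_true == label)     # everything that really is `label`
    return tp / predicted, tp / actual

p, r = precision_recall(y_true, y_pred, 1)
print("precision=%.2f, recall=%.2f" % (p, r))
```

In this toy case every true 1 is found (recall 1.00), but two other samples are also labelled 1, which lowers the precision to 0.60. The same pattern, on a much smaller scale, explains a class with perfect recall but precision below 1.00.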



An overall accuracy of 98% on the test set is a strong result for such a simple classifier. Digits 0, 2, and 6 achieve perfect precision and recall, and digit 7 has perfect precision with a recall of 0.98. Digit 1 has the lowest precision (0.95), while digit 9 has the lowest recall (0.94) and the lowest F1-score (0.95).
High accuracy on an MNIST-style dataset does not mean handwritten digit recognition is "solved." Although MNIST is a standard benchmark, its images are heavily pre-processed (cropped, thresholded, and centered), which does not reflect real-world conditions. Real-world datasets are often less clean and require feature extraction beyond raw pixel intensities. Nonetheless, this exercise demonstrates how Euclidean distance on raw pixels can yield high accuracy when the data is well pre-processed.
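
To show what the classifier does conceptually, here is a minimal NumPy sketch of k-NN prediction with Euclidean distance and a majority vote. It is an illustration under our own assumptions, not how scikit-learn implements `KNeighborsClassifier`; the function `knn_predict` and the toy data are hypothetical.

```python
import numpy as np

def knn_predict(train_x, train_y, query, k=1):
    # Euclidean distance from the query to every training vector
    dists = np.sqrt(((train_x - query) ** 2).sum(axis=1))
    # indices of the k closest training points
    nearest = np.argsort(dists)[:k]
    # majority vote among their labels (ties go to the smallest label)
    votes = np.bincount(train_y[nearest])
    return int(np.argmax(votes))

# toy 2-D data: two points per class
train_x = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [6.0, 5.0]])
train_y = np.array([0, 0, 1, 1])

pred_a = knn_predict(train_x, train_y, np.array([0.2, 0.4]), k=1)
pred_b = knn_predict(train_x, train_y, np.array([5.5, 5.1]), k=3)
print(pred_a, pred_b)
```

For the digits data, `train_x` would hold the 64-dimensional pixel vectors, so the distance is simply the pixel-wise difference between two images.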

Let’s conclude this code example by reviewing some individual predictions made by our k-NN classifier.

In [None]:
# loop over a few random digits
for i in np.random.randint(0, high=len(testLabels), size=5):
    # grab the image and classify it
    image = testData[i]
    prediction = model.predict(image.reshape(1, -1))[0]

    # convert the image from a 64-dim array to an 8 x 8 image, stretch its
    # intensities to [0, 255], and cast to uint8 so OpenCV displays it correctly
    image = image.reshape((8, 8)).astype("uint8")
    image = exposure.rescale_intensity(image, out_range=(0, 255)).astype("uint8")

    # resize the image to 32 x 32 pixels so we can see it better
    image = imutils.resize(image, width=32, inter=cv2.INTER_CUBIC)

    # show the prediction and display the image
    print("I think that digit is: {}".format(prediction))
    cv2.imshow("Image", image)

    # wait until a key is pressed before moving on to the next digit
    cv2.waitKey(0)

# close all OpenCV windows
cv2.destroyAllWindows()

