# MNIST handwritten digits classification with nearest neighbors 

In this notebook, we'll use [nearest-neighbor classifiers](http://scikit-learn.org/stable/modules/neighbors.html#nearest-neighbors-classification) to classify MNIST digits using scikit-learn.

First, the needed imports. 

In [None]:
%matplotlib inline

from time import time
import numpy as np
from sklearn import neighbors
from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns

Then we load the MNIST data. The first time this cell is run, the data is downloaded, which can take a while.

In [None]:
from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

print()
print('MNIST data loaded: train:',len(X_train),'test:',len(X_test))
print('X_train:', X_train.shape)
print('y_train:', y_train.shape)
print('X_test:', X_test.shape)
print('y_test:', y_test.shape)

The training data (`X_train`) is a 3rd-order tensor of size (60000, 28, 28), i.e. it consists of 60000 images of size 28x28 pixels. `y_train` is a 60000-dimensional vector containing the correct classes ("0", "1", ..., "9") for each training digit.

Let's take a closer look. Here are the first 10 training digits:

In [None]:
pltsize=1
plt.figure(figsize=(10*pltsize, pltsize))

for i in range(10):
    plt.subplot(1,10,i+1)
    plt.axis('off')
    plt.imshow(X_train[i,:,:], cmap="gray")
    plt.title('Class: '+str(y_train[i]))

## 1-NN classifier

Let's first create a 1-NN classifier. Notice the `reshape(-1,28*28)` call, which flattens the 2-D images into 1-D vectors (from 28x28 pixel images to 784-dimensional vectors).
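
As a small illustration of what the flattening does, here is a sketch on a synthetic array standing in for a batch of images (the `images` array is made up for this example and is not the MNIST data). The `-1` lets NumPy infer the number of rows, and each image is laid out row by row.

```python
import numpy as np

# Synthetic stand-in for a batch of 5 images of 28x28 pixels
# (illustrative values only, not the MNIST arrays).
images = np.arange(5 * 28 * 28).reshape(5, 28, 28)

# -1 tells NumPy to infer the first dimension (here 5) from the total size.
flat = images.reshape(-1, 28 * 28)

print('before:', images.shape)
print('after: ', flat.shape)

# Row-major layout: pixel (row, col) of image i lands in column row*28 + col.
row, col = 3, 7
print(flat[2, row * 28 + col] == images[2, row, col])

# The operation is lossless: reshaping back recovers the original images.
restored = flat.reshape(-1, 28, 28)
print(np.array_equal(restored, images))
```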

In [None]:
n_neighbors = 1
clf = neighbors.KNeighborsClassifier(n_neighbors)
clf.fit(X_train.reshape(-1,28*28), y_train)

Next, we use it to classify the first 100 test samples.

In [None]:
t0 = time()
predictions = clf.predict(X_test[:100,:,:].reshape(-1,28*28))
print('Time elapsed: %.2fs' % (time()-t0))

We observe that the classifier is rather slow, and classifying the whole test set would take quite some time. What is the reason for this?
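
One way to think about this question is to write out a brute-force 1-NN by hand. The sketch below is an illustration on random data, not scikit-learn's implementation (its default `algorithm='auto'` may pick a tree-based search instead), and the helper name `nearest_neighbor_predict` is made up for this example. The point is that no work happens at fit time, so every query has to be compared against every stored training vector.

```python
import numpy as np

def nearest_neighbor_predict(X_train, y_train, X_query):
    """Brute-force 1-NN: give each query the label of its closest training vector."""
    preds = np.empty(len(X_query), dtype=y_train.dtype)
    for i, q in enumerate(X_query):
        # Squared Euclidean distance from q to every training vector.
        # Note: uint8 image data should be cast to float first, otherwise
        # the subtraction wraps around.
        d2 = np.sum((X_train - q) ** 2, axis=1)
        preds[i] = y_train[np.argmin(d2)]
    return preds

# Random stand-in data with the same dimensionality as flattened MNIST images.
rng = np.random.default_rng(0)
X_tr = rng.random((1000, 784))
y_tr = rng.integers(0, 10, size=1000)

# Queries are slightly perturbed copies of the first five training vectors,
# so each one should be matched to its own original.
X_q = X_tr[:5] + 0.001 * rng.random((5, 784))
preds = nearest_neighbor_predict(X_tr, y_tr, X_q)
print('matches originals:', np.array_equal(preds, y_tr[:5]))

# Work grows with (number of queries) x (number of training vectors) x (dimensions).
n_ops = len(X_q) * len(X_tr) * X_tr.shape[1]
print('coordinate differences computed:', n_ops)
```

With the sizes used in this notebook, the same count would be 100 x 60000 x 784, about 4.7 billion coordinate differences for just 100 test digits, and 100 times that for the whole test set.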

The accuracy of the classifier:

In [None]:
print('Predicted', len(predictions), 'digits with accuracy:', accuracy_score(y_test[:100], predictions))

## Faster 1-NN classifier

One way to make our 1-NN classifier faster is to use less training data:

In [None]:
n_neighbors = 1
clf_reduced = neighbors.KNeighborsClassifier(n_neighbors)
clf_reduced.fit(X_train[:1024,:,:].reshape(-1,28*28), y_train[:1024])

Now we can use the classifier created with reduced data to classify our whole test set in a reasonable amount of time.

In [None]:
t0 = time()
predictions_reduced = clf_reduced.predict(X_test.reshape(-1,28*28))
print('Time elapsed: %.2fs' % (time()-t0))

However, the classification accuracy is now lower:

In [None]:
print('Predicted', len(predictions_reduced), 'digits with accuracy:', accuracy_score(y_test, predictions_reduced))

We can also inspect the results in more detail. Let's see some test digits the model got wrong.

In [None]:
maxtoshow = 10
errors = predictions_reduced!=y_test
print('Showing', maxtoshow, 'first failures.  The predicted class is shown first and the correct class in parentheses.')
ii = 0
plt.figure(figsize=(maxtoshow*pltsize, pltsize))
for i in range(X_test.shape[0]):
    if ii>=maxtoshow:
        break
    if errors[i]:
        plt.subplot(1, maxtoshow, ii+1)
        plt.axis('off')
        plt.imshow(X_test[i,:,:], cmap="gray")
        plt.title("%d (%d)" % (predictions_reduced[i], y_test[i]))
        ii = ii + 1

We can observe that the classifier makes rather "easy" mistakes, and there seems to be room for improvement.
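
Another way to inspect the errors is a confusion matrix, which counts how often each true class is predicted as each other class. The sketch below uses scikit-learn's `confusion_matrix` on a handful of made-up labels (`y_true_demo` and `y_pred_demo` are illustrative only); in the notebook the same call could presumably be applied to `y_test` and `predictions_reduced`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Made-up labels for illustration only (three classes: 0, 1, 2).
y_true_demo = np.array([0, 1, 2, 2, 1, 0, 2])
y_pred_demo = np.array([0, 2, 2, 2, 1, 0, 1])

# Rows correspond to true classes, columns to predicted classes.
cm = confusion_matrix(y_true_demo, y_pred_demo)
print(cm)

# Off-diagonal entries are the mistakes; the diagonal holds correct predictions.
n_errors = cm.sum() - np.trace(cm)
print('errors:', n_errors)
```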

## Model tuning

Try to improve the accuracy of the nearest-neighbor classifier while keeping the time needed to classify the whole test set reasonable. Possible things to try include using more than one neighbor or increasing the amount of training data. See also http://scikit-learn.org/stable/modules/neighbors.html#nearest-neighbors-classification for more information.
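
As a starting point, the loop below sketches one possible way to compare several settings. To keep it fast and self-contained, it assumes scikit-learn's small built-in 8x8 digits dataset (`load_digits`) as a stand-in for MNIST; the training sizes and values of `n_neighbors` are arbitrary choices for illustration, not recommendations.

```python
from time import time

from sklearn import neighbors
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Stand-in data: 1797 digits of 8x8 pixels, already flattened to 64 features.
X, y = load_digits(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=42)

results = {}
for n_train in (256, 1024):      # amount of training data
    for k in (1, 3, 5):          # number of neighbors
        clf = neighbors.KNeighborsClassifier(n_neighbors=k)
        clf.fit(X_tr[:n_train], y_tr[:n_train])

        # Time only the prediction step, as in the cells above.
        t0 = time()
        pred = clf.predict(X_te)
        elapsed = time() - t0

        acc = accuracy_score(y_te, pred)
        results[(n_train, k)] = acc
        print('n_train=%4d  k=%d  accuracy=%.3f  time=%.2fs' % (n_train, k, acc, elapsed))
```

For MNIST, the same loop could be run with `X_train.reshape(-1,28*28)` and `y_train` in place of the stand-in arrays, keeping an eye on the prediction time as the training set grows.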