# MNIST handwritten digits classification with nearest neighbors 

In this notebook, we'll use [nearest-neighbor classifiers](http://scikit-learn.org/stable/modules/neighbors.html#nearest-neighbors-classification) to classify MNIST digits using scikit-learn (version 0.20 or later required).

First, the needed imports. 

In [None]:
%matplotlib inline

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import neighbors, datasets, __version__
from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

# distutils was removed in Python 3.12; packaging is installed as a matplotlib dependency.
from packaging.version import Version
assert Version(__version__) >= Version("0.20"), "Version >= 0.20 of sklearn is required."

Then we load the MNIST data. The first time this is run, the data has to be downloaded, which can take a while.

In [None]:
mnist = datasets.fetch_openml('mnist_784')

# Newer scikit-learn versions may return pandas objects here; convert to
# NumPy arrays so that the X_train[i,:] indexing used below works.
X, y = np.asarray(mnist['data']), np.asarray(mnist['target'])

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=10000, shuffle=True)

print()
print('MNIST data loaded: train:',len(X_train),'test:',len(X_test))
print('X_train:', X_train.shape)
print('y_train:', y_train.shape)
print('X_test:', X_test.shape)
print('y_test:', y_test.shape)

The training data (`X_train`) is a matrix of size (60000, 784), i.e. it consists of 60000 digits expressed as 784-dimensional vectors (28x28 images flattened to 1D). `y_train` is a 60000-dimensional vector containing the correct class ("0", "1", ..., "9") for each training digit.
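
As a small illustration of the flattening, the sketch below uses a made-up 28x28 array (not real MNIST data) and assumes row-major order, which is what the `reshape(28, 28)` calls in this notebook rely on.

```python
import numpy as np

# Hypothetical 28x28 "image" with distinct pixel values (not real MNIST data).
image = np.arange(28 * 28).reshape(28, 28)

# Flatten row by row into a 784-dimensional vector, like one row of X_train.
vector = image.reshape(-1)
print(vector.shape)  # (784,)

# In row-major order, pixel (row, col) ends up at index row * 28 + col.
row, col = 3, 5
assert vector[row * 28 + col] == image[row, col]

# Reshaping back recovers the original image.
restored = vector.reshape(28, 28)
assert np.array_equal(restored, image)
```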

Let's take a closer look. Here are the first 10 training digits:

In [None]:
pltsize=1
plt.figure(figsize=(10*pltsize, pltsize))

for i in range(10):
    plt.subplot(1,10,i+1)
    plt.axis('off')
    plt.imshow(X_train[i,:].reshape(28, 28), cmap="gray")
    plt.title('Class: '+y_train[i])

## k-NN (k-nearest neighbors) classifier

![title](imgs/500px-KnnClassification.svg.png)

<br/>

<center><small>Image by Antti Ajanki AnAj (Own work) [<a href="http://www.gnu.org/copyleft/fdl.html">GFDL</a>, <a href="http://creativecommons.org/licenses/by-sa/3.0/">CC-BY-SA-3.0</a> or <a href="http://creativecommons.org/licenses/by-sa/2.5-2.0-1.0">CC BY-SA 2.5-2.0-1.0</a>], <a href="https://commons.wikimedia.org/wiki/File%3AKnnClassification.svg">via Wikimedia Commons</a></small></center>


## 1-NN classifier

### Initialization

Let's first create a 1-NN classifier. Note that with nearest-neighbor classifiers there is no internal (parameterized) model and therefore no learning required. Instead, calling the `fit()` function simply stores the samples of the training data in a suitable data structure.

In [None]:
%%time

n_neighbors = 1
clf = neighbors.KNeighborsClassifier(n_neighbors)
clf.fit(X_train, y_train)
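
To make the idea of "storing samples" concrete, here is a minimal brute-force sketch of 1-NN prediction in NumPy. The function `predict_1nn` and the tiny 2D data set are hypothetical and only for illustration. This is not how scikit-learn is implemented; depending on its `algorithm` setting it may use tree-based index structures instead of an exhaustive search.

```python
import numpy as np

def predict_1nn(X_stored, y_stored, X_query):
    """Brute-force 1-NN: return the label of the closest stored sample."""
    # Pairwise squared Euclidean distances, shape (n_query, n_stored).
    diffs = X_query[:, np.newaxis, :] - X_stored[np.newaxis, :, :]
    dists = np.sum(diffs ** 2, axis=2)
    # Index of the nearest stored sample for each query.
    nearest = np.argmin(dists, axis=1)
    return y_stored[nearest]

# Tiny made-up 2D data set instead of MNIST.
X_stored = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
y_stored = np.array(["0", "1", "5"])
X_query = np.array([[0.2, 0.1], [4.0, 4.5]])

print(predict_1nn(X_stored, y_stored, X_query))  # ['0' '5']
```

In this sketch every query is compared against every stored sample, so the cost of a prediction grows with the size of the stored data.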

### Inference

And try to classify some test samples with it.

In [None]:
%%time

predictions = clf.predict(X_test[:200,:])

We observe that the classifier is rather slow, and classifying the whole test set would take quite some time. What is the reason for this?

The accuracy of the classifier:

In [None]:
print('Predicted', len(predictions), 'digits with accuracy:', accuracy_score(y_test[:len(predictions)], predictions))
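
For reference, the accuracy is simply the fraction of predictions that match the true labels. The sketch below checks this against `accuracy_score` on a few made-up labels (not actual classifier output).

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Made-up labels for illustration: four of five predictions are correct.
y_true = np.array(["7", "2", "1", "0", "4"])
y_pred = np.array(["7", "2", "1", "0", "9"])

# Fraction of positions where prediction and true label agree.
manual = np.mean(y_true == y_pred)
print(manual)  # 0.8

assert np.isclose(manual, accuracy_score(y_true, y_pred))
```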

## Faster 1-NN classifier

### Initialization

One way to make our 1-NN classifier faster is to use less training data:

In [None]:
%%time

n_neighbors = 1
n_data = 1024
clf_reduced = neighbors.KNeighborsClassifier(n_neighbors)
clf_reduced.fit(X_train[:n_data,:], y_train[:n_data])

### Inference

Now we can use the classifier created with reduced data to classify our whole test set in a reasonable amount of time.

In [None]:
%%time

predictions_reduced = clf_reduced.predict(X_test)

However, the classification accuracy is now lower:

In [None]:
print('Predicted', len(predictions_reduced), 'digits with accuracy:', accuracy_score(y_test, predictions_reduced))

We can also inspect the results in more detail. Let's define and use a helper function to show the wrongly classified test digits.

In [None]:
def show_failures(predictions, trueclass=None, predictedclass=None, maxtoshow=10):
    """Plot misclassified test digits, optionally filtered by true or predicted class.

    Uses the global X_test and y_test; predictions must cover the whole test set.
    """
    errors = predictions!=y_test
    print('Showing max', maxtoshow, 'first failures. '
          'The predicted class is shown first and the correct class in parenthesis.')
    ii = 0  # number of failures plotted so far
    plt.figure(figsize=(maxtoshow, 1))
    for i in range(X_test.shape[0]):
        if ii>=maxtoshow:
            break
        if errors[i]:
            if trueclass is not None and y_test[i] != trueclass:
                continue
            if predictedclass is not None and predictions[i] != predictedclass:
                continue
            plt.subplot(1, maxtoshow, ii+1)
            plt.axis('off')
            plt.imshow(X_test[i,:].reshape(28, 28), cmap="gray")
            plt.title("%s (%s)" % (predictions[i], y_test[i]))
            ii = ii + 1

In [None]:
show_failures(predictions_reduced)

We can observe that the classifier makes rather "easy" mistakes, and there might be room for improvement.
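
Another way to look at the errors in aggregate is a confusion matrix, which counts how often each true class is predicted as each class. The sketch below uses `sklearn.metrics.confusion_matrix` on a few made-up labels; in this notebook one would presumably pass `y_test` and `predictions_reduced` instead.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Made-up labels for illustration (not actual classifier output).
y_true = np.array(["3", "3", "5", "5", "8"])
y_pred = np.array(["3", "5", "5", "5", "3"])
labels = ["3", "5", "8"]

# Rows correspond to true classes, columns to predicted classes.
cm = confusion_matrix(y_true, y_pred, labels=labels)
print(cm)
# [[1 1 0]
#  [0 2 0]
#  [1 0 0]]
```

Off-diagonal entries show which pairs of digits get confused, for example the single "8" predicted as "3" in the last row.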

## Model tuning

Try to improve the accuracy of the nearest-neighbor classifier while preserving a reasonable runtime to classify the whole test set. Things to try include using more than one neighbor (with or without weights) or increasing the amount of training data.  See the documentation for [KNeighborsClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn-neighbors-kneighborsclassifier).

See also http://scikit-learn.org/stable/modules/neighbors.html#nearest-neighbors-classification for more information.
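
As a starting point, the sketch below compares a few values of `n_neighbors` with uniform and distance-based `weights`. To keep it quick and self-contained, it assumes the small 8x8 digits data set bundled with scikit-learn (`load_digits`) as a stand-in for MNIST, and the variable names are only illustrative. The same loop could be run on `X_train` and `X_test` from this notebook, at a much higher runtime.

```python
from sklearn import datasets, neighbors
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Stand-in data: the small bundled 8x8 digits data set, not MNIST.
digits = datasets.load_digits()
X_tr, X_te, y_tr, y_te = train_test_split(
    digits.data, digits.target, test_size=0.25, random_state=0)

# Accuracy for each combination of neighbor count and weighting scheme.
results = {}
for k in (1, 3, 5):
    for weights in ("uniform", "distance"):
        clf_k = neighbors.KNeighborsClassifier(n_neighbors=k, weights=weights)
        clf_k.fit(X_tr, y_tr)
        results[(k, weights)] = accuracy_score(y_te, clf_k.predict(X_te))

for (k, weights), acc in sorted(results.items()):
    print('k=%d weights=%-8s accuracy=%.4f' % (k, weights, acc))
```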